Basic Layers

GNNLux.GNNLayer - Type
abstract type GNNLayer <: AbstractLuxLayer end

An abstract type from which graph neural network layers are derived. It is a subtype of Lux's AbstractLuxLayer type.

See also GNNLux.GNNChain.
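
Custom layers are defined by subtyping GNNLayer and implementing the usual Lux layer interface. The following is a minimal sketch, not part of the library: the layer name ScaleFeatures, its dims field, and its parameter initializer are made up for illustration, and LuxCore is assumed to be available (as in the GNNChain example further below).

using Lux, GNNLux, Random
import LuxCore

# Hypothetical layer that rescales node features (illustrative only).
struct ScaleFeatures <: GNNLux.GNNLayer
    dims::Int
end

# Trainable parameters: one scale per feature dimension.
LuxCore.initialparameters(rng::AbstractRNG, l::ScaleFeatures) =
    (; scale = ones(Float32, l.dims))

# No non-trainable state.
LuxCore.initialstates(::AbstractRNG, ::ScaleFeatures) = NamedTuple()

# Because the layer subtypes GNNLayer, a GNNChain passes the graph g as the
# first argument; this toy layer ignores g and just rescales x.
(l::ScaleFeatures)(g, x, ps, st) = (ps.scale .* x, st)

Such a layer can then be placed in a GNNChain alongside plain Lux layers, as in the example below.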

GNNLux.GNNChain - Type
GNNChain(layers...)
GNNChain(name = layer, ...)

Collects multiple layers / functions to be called in sequence on a given input graph and its input node features.

It allows composing layers in a sequential fashion, as Lux.Chain does, propagating the output of each layer to the next one. In addition, GNNChain handles the input graph, providing it as a first argument only to layers subtyping the GNNLayer abstract type.

GNNChain supports indexing and slicing, e.g. m[2] or m[1:end-1]. If names are given, layers can also be accessed by name, so that m[:name] == m[1], etc. (a short named-chain sketch follows the example below).

Examples

julia> using Lux, GNNLux, Random

julia> rng = Random.default_rng();

julia> m = GNNChain(GCNConv(2 => 5, relu), Dense(5 => 4))
GNNChain(
    layers = NamedTuple(
        layer_1 = GCNConv(2 => 5, relu),  # 15 parameters
        layer_2 = Dense(5 => 4),        # 24 parameters
    ),
)         # Total: 39 parameters,
          #        plus 0 states.

julia> x = randn(rng, Float32, 2, 3);

julia> g = rand_graph(rng, 3, 6)
GNNGraph:
  num_nodes: 3
  num_edges: 6

julia> ps, st = LuxCore.setup(rng, m);

julia> y, st = m(g, x, ps, st);     # First entry is the output, second entry is the state of the model

julia> size(y)
(4, 3)
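
Names can also be attached to the layers at construction time and then used for indexing. A brief sketch continuing the example above (the names enc and dec are arbitrary, chosen only for illustration):

m = GNNChain(enc = GCNConv(2 => 5, relu), dec = Dense(5 => 4))

m[:enc]        # the GCNConv layer, same as m[1]
m[1:end-1]     # sub-chain containing only the first (named) layer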