Basic Layers
GraphNeuralNetworks.DotDecoder (Type)
DotDecoder()
A graph neural network layer that, for a given input graph g and node features x, returns the dot product x_i ⋅ x_j for each edge (i, j), where x_i is the feature vector of node i.
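The per-edge computation can be sketched in plain Julia without the library. This is a minimal illustration under stated assumptions, not DotDecoder's implementation: edges are assumed to arrive as two index vectors s (sources) and t (targets), which for a GNNGraph could presumably be obtained with edge_index(g), and edge_dot is a hypothetical name.

```julia
using LinearAlgebra: dot

# Hypothetical helper, NOT part of GraphNeuralNetworks.jl: for every edge i => j,
# given as source indices s and target indices t, compute x_i ⋅ x_j.
function edge_dot(s::AbstractVector{<:Integer}, t::AbstractVector{<:Integer},
                  x::AbstractMatrix)
    # x stores one column of features per node
    scores = [dot(view(x, :, i), view(x, :, j)) for (i, j) in zip(s, t)]
    # shape the result as a 1 × num_edges row, like DotDecoder's output
    return reshape(scores, 1, :)
end

s = [1, 2, 3]          # edge sources
t = [2, 3, 1]          # edge targets
x = [1.0 0.0 2.0;
     1.0 3.0 0.0]      # 2 features × 3 nodes
y = edge_dot(s, t, x)  # 1 × 3 matrix: [3.0 0.0 2.0]
```

The 1 × num_edges shape mirrors the (1, 6) size returned by DotDecoder for a graph with 6 edges.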
Examples
julia> using GraphNeuralNetworks

julia> g = rand_graph(5, 6)
GNNGraph:
  num_nodes: 5
  num_edges: 6

julia> dotdec = DotDecoder()
DotDecoder()

julia> dotdec(g, rand(2, 5)) |> size
(1, 6)
GraphNeuralNetworks.GNNChain (Type)
GNNChain(layers...)
GNNChain(name = layer, ...)
Collects multiple layers/functions to be called in sequence on a given input graph and input node features.
It composes layers sequentially, as Flux.Chain does, propagating the output of each layer to the next one. In addition, GNNChain handles the input graph, passing it as the first argument only to layers that subtype the GNNLayer abstract type.
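The graph-forwarding rule can be illustrated with a toy chain in plain Julia. This is a sketch, not GNNChain's source: MyGNNLayer, NeighborSum, MiniChain and applylayer are hypothetical names, and the graph is simplified to a vector of (source, target) pairs. Dispatch on the abstract type decides whether a layer receives (g, x) or only x.

```julia
# Hypothetical stand-in for the GNNLayer abstract type (not the library's type).
abstract type MyGNNLayer end

# Toy graph layer: each node adds the features of its in-neighbours to its own.
struct NeighborSum <: MyGNNLayer end
function (l::NeighborSum)(g, x)
    y = copy(x)
    for (s, t) in g            # g is a vector of (source, target) pairs
        y[:, t] .+= x[:, s]
    end
    return y
end

# Hypothetical chain storing its layers in a tuple.
struct MiniChain{T<:Tuple}
    layers::T
end
MiniChain(layers...) = MiniChain(layers)

# Graph layers receive (g, x); any other callable receives only x.
applylayer(l::MyGNNLayer, g, x) = l(g, x)
applylayer(l, g, x) = l(x)

function (c::MiniChain)(g, x)
    for l in c.layers
        x = applylayer(l, g, x)  # output of one layer feeds the next
    end
    return x
end

g = [(1, 2), (2, 3), (3, 1)]
x = [1.0 2.0 3.0]              # 1 feature × 3 nodes
m = MiniChain(NeighborSum(), x -> 2 .* x)
y = m(g, x)                    # [8.0 6.0 10.0]
```

Because the choice is made by method dispatch, plain functions such as x -> 2 .* x can be mixed freely with graph layers, as in the GNNChain example that uses x -> relu.(x).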
GNNChain supports indexing and slicing, e.g. m[2] or m[1:end-1]. If names are given, layers can also be retrieved by name (for instance, m[:name] == m[1] when the first layer was given the name name).
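Positional and name-based lookup can be sketched with a NamedTuple. This assumes, for illustration only, that named layers are stored in a NamedTuple; NamedLayers is a hypothetical container, not GNNChain itself.

```julia
# Hypothetical container (not GNNChain itself): layers kept in a NamedTuple,
# so they can be looked up by position or by name.
struct NamedLayers{T<:NamedTuple}
    layers::T
end
NamedLayers(; kws...) = NamedLayers((; kws...))

Base.getindex(c::NamedLayers, i::Integer) = c.layers[i]
Base.getindex(c::NamedLayers, name::Symbol) = c.layers[name]
Base.lastindex(c::NamedLayers) = length(c.layers)   # enables m[1:end-1]

# Slicing keeps the names of the selected layers.
function Base.getindex(c::NamedLayers, r::AbstractUnitRange{<:Integer})
    names = keys(c.layers)[r]
    return NamedLayers(NamedTuple{names}(Tuple(c.layers)[r]))
end

m = NamedLayers(enc = x -> 2 .* x, dec = x -> x .+ 1)
first_part = m[1:end-1]   # keeps only the layer named :enc
```

Here m[:enc] and m[1] return the very same layer object, which is the property the m[:name] == m[1] example relies on.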
Examples
julia> using Flux, GraphNeuralNetworks

julia> m = GNNChain(GCNConv(2=>5),
                    BatchNorm(5),
                    x -> relu.(x),
                    Dense(5, 4));

julia> x = randn(Float32, 2, 3);

julia> g = rand_graph(3, 6)
GNNGraph:
  num_nodes: 3
  num_edges: 6

julia> m(g, x) |> size
(4, 3)

julia> m2 = GNNChain(enc = m, dec = DotDecoder());

julia> m2(g, x) |> size
(1, 6)

julia> m2[:enc](g, x) == m(g, x)
true
GraphNeuralNetworks.GNNLayer (Type)
abstract type GNNLayer end
An abstract type from which graph neural network layers are derived.
See also GNNChain.
GraphNeuralNetworks.WithGraph (Type)
WithGraph(model, g::GNNGraph; traingraph=false)
A type wrapping model and tying it to the graph g. In the forward pass it needs only the feature arrays as inputs, and it returns model(g, x...; kws...).
If traingraph=false, the graph's parameters are not included in the trainable parameters during gradient updates.
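The wrapping behaviour, including the override shown in the example below, can be sketched with two callable methods on a small struct. This is a hypothetical re-creation, not the library's code: EdgeList, MyWithGraph and aggregate_in are illustrative names, and the traingraph option is not modelled.

```julia
# Hypothetical graph type for this sketch: a list of (source, target) edges.
struct EdgeList
    edges::Vector{Tuple{Int,Int}}
end

# Hypothetical wrapper mimicking the WithGraph idea (not the library's code).
struct MyWithGraph{M}
    model::M
    g::EdgeList
end

# Features only: reuse the stored graph.
(wg::MyWithGraph)(x...; kws...) = wg.model(wg.g, x...; kws...)
# A graph passed explicitly takes precedence over the stored one.
(wg::MyWithGraph)(g::EdgeList, x...; kws...) = wg.model(g, x...; kws...)

# Toy model: each node sums the features of its in-neighbours.
function aggregate_in(g::EdgeList, x::AbstractMatrix)
    y = zero(x)
    for (s, t) in g.edges
        y[:, t] .+= x[:, s]
    end
    return y
end

wg = MyWithGraph(aggregate_in, EdgeList([(1, 2), (2, 3), (3, 1)]))
y1 = wg([1.0 2.0 3.0])                          # uses the stored graph
y2 = wg(EdgeList([(1, 1), (2, 1)]), [1.0 2.0])  # uses the new graph
```

Dispatching on the graph type is what lets one wrapper accept both call forms: when the first argument is a graph, the more specific method is selected and the stored graph is ignored.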
Examples
using GraphNeuralNetworks

g = GNNGraph([1,2,3], [2,3,1])
x = rand(Float32, 2, 3)
model = SAGEConv(2 => 3)
wg = WithGraph(model, g)
# No need to feed the graph to `wg`
@assert wg(x) == model(g, x)

g2 = GNNGraph([1,1,2,3], [2,4,1,1])
x2 = rand(Float32, 2, 4)
# WithGraph will ignore the internal graph if fed with a new one.
@assert wg(g2, x2) == model(g2, x2)