Basic Layers
Index
GraphNeuralNetworks.DotDecoder
GraphNeuralNetworks.GNNChain
GraphNeuralNetworks.GNNLayer
GraphNeuralNetworks.WithGraph
Docs
GraphNeuralNetworks.DotDecoder (Type)
DotDecoder()
A graph neural network layer that, for a given input graph g and node features x, returns the dot product x_i ⋅ x_j on each edge (i, j).
Examples
julia> g = rand_graph(5, 6)
GNNGraph:
num_nodes = 5
num_edges = 6
julia> dotdec = DotDecoder()
DotDecoder()
julia> dotdec(g, rand(2, 5))
1×6 Matrix{Float64}:
0.345098 0.458305 0.106353 0.345098 0.458305 0.106353
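As a rough illustration of what the layer computes, the package-free sketch below evaluates x_i ⋅ x_j for every edge of a graph given as two endpoint vectors s and t (edge k goes from s[k] to t[k]). The helper edge_dot and this edge-list representation are hypothetical; DotDecoder itself takes a GNNGraph and may be implemented differently.

```julia
# Hypothetical helper, not part of GraphNeuralNetworks.jl.
# s, t: source and target node of each edge; x: one feature column per node.
# Returns a 1×E row holding the dot product of the two endpoint features per edge.
edge_dot(s::AbstractVector{<:Integer}, t::AbstractVector{<:Integer}, x::AbstractMatrix) =
    sum(x[:, s] .* x[:, t]; dims = 1)

s = [1, 2, 3, 2]
t = [2, 3, 1, 1]
x = Float64[1 0 2;
            1 3 0]   # node 1 = (1, 1), node 2 = (0, 3), node 3 = (2, 0)
y = edge_dot(s, t, x)  # 1×4 Matrix{Float64}: 3.0 0.0 2.0 3.0
```

Gathering the endpoint columns with x[:, s] and x[:, t] avoids an explicit loop over edges, at the cost of allocating two d×E matrices.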
GraphNeuralNetworks.GNNChain (Type)
GNNChain(layers...)
GNNChain(name = layer, ...)
Collects multiple layers / functions to be called in sequence on a given input graph and input node features.
Like Flux.Chain, it composes layers sequentially, propagating the output of each layer to the next one. In addition, GNNChain handles the input graph as well, passing it as the first argument only to layers that subtype the GNNLayer abstract type.
GNNChain supports indexing and slicing, e.g. m[2] or m[1:end-1], and, if names are given, m[:name] == m[1], etc.
Examples
julia> using Flux, GraphNeuralNetworks
julia> m = GNNChain(GCNConv(2=>5),
BatchNorm(5),
x -> relu.(x),
Dense(5, 4))
GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4))
julia> x = randn(Float32, 2, 3);
julia> g = rand_graph(3, 6)
GNNGraph:
num_nodes = 3
num_edges = 6
julia> m(g, x)
4×3 Matrix{Float32}:
-0.795592 -0.795592 -0.795592
-0.736409 -0.736409 -0.736409
0.994925 0.994925 0.994925
0.857549 0.857549 0.857549
julia> m2 = GNNChain(enc = m,
dec = DotDecoder())
GNNChain(enc = GNNChain(GCNConv(2 => 5), BatchNorm(5), #7, Dense(5 => 4)), dec = DotDecoder())
julia> m2(g, x)
1×6 Matrix{Float32}:
2.90053 2.90053 2.90053 2.90053 2.90053 2.90053
julia> m2[:enc](g, x) == m(g, x)
true
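The rule that only GNNLayer subtypes receive the graph can be pictured with a small, package-free sketch based on dispatch. SketchGNNLayer, SumNeighbors, SketchChain, apply_layer and the NamedTuple edge-list graph are hypothetical stand-ins, not GraphNeuralNetworks.jl API, and the real GNNChain may work differently internally.

```julia
# Hypothetical stand-in for the GNNLayer abstract type.
abstract type SketchGNNLayer end

# Toy graph layer: each node receives the sum of its in-neighbours' features.
struct SumNeighbors <: SketchGNNLayer end
function (::SumNeighbors)(g, x)
    y = zero(x)
    for (s, t) in zip(g.s, g.t)   # edge s -> t
        y[:, t] .+= x[:, s]
    end
    return y
end

struct SketchChain{T<:Tuple}
    layers::T
end
SketchChain(layers...) = SketchChain(layers)

# Graph layers are called as l(g, x); any other callable only as l(x).
apply_layer(l::SketchGNNLayer, g, x) = l(g, x)
apply_layer(l, g, x) = l(x)

# Thread the features through the layers, handing the same graph to each step.
(c::SketchChain)(g, x) = foldl((h, l) -> apply_layer(l, g, h), c.layers; init = x)

g = (s = [1, 2, 3], t = [2, 3, 1])   # directed 3-cycle
x = Float64[1 2 3]
m = SketchChain(SumNeighbors(), h -> 2 .* h)
y = m(g, x)   # 1×3 Matrix{Float64}: 6.0 2.0 4.0
```

Dispatching on the abstract type keeps the chain itself generic: plain functions such as h -> 2 .* h never see the graph.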
GraphNeuralNetworks.GNNLayer (Type)
abstract type GNNLayer end
An abstract type from which graph neural network layers are derived.
See also GNNChain.
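Custom layers join this hierarchy by subtyping the abstract type. So that it runs without the package, the sketch below subtypes a locally defined stand-in, AbstractSketchLayer, instead of GNNLayer and uses a NamedTuple edge list instead of a GNNGraph; MeanConv is a hypothetical toy layer (mean of in-neighbour features followed by a linear map), not a layer shipped by GraphNeuralNetworks.jl.

```julia
# Hypothetical stand-in for GNNLayer so the sketch runs without the package.
abstract type AbstractSketchLayer end

# Toy layer: average the in-neighbour features of each node, then apply W.
struct MeanConv{M<:AbstractMatrix} <: AbstractSketchLayer
    W::M   # out × in weight matrix
end

function (l::MeanConv)(g, x::AbstractMatrix)
    d, n = size(x)
    agg = zeros(eltype(x), d, n)
    deg = zeros(Int, n)
    for (s, t) in zip(g.s, g.t)   # edge s -> t sends x[:, s] to node t
        agg[:, t] .+= x[:, s]
        deg[t] += 1
    end
    agg ./= max.(deg, 1)'          # nodes with no in-edges keep zeros
    return l.W * agg
end

g = (s = [1, 2, 3, 1], t = [2, 3, 1, 3])
x = Float64[1 2 4]
layer = MeanConv(fill(2.0, 1, 1))
y = layer(g, x)   # 1×3 Matrix{Float64}: 8.0 2.0 3.0
```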
GraphNeuralNetworks.WithGraph (Type)
WithGraph(model, g::GNNGraph; traingraph=false)
A type wrapping the model and tying it to the graph g. In the forward pass it can be called with feature arrays only, returning model(g, x...; kws...).
If traingraph=false, the graph's parameters are not included in the trainable parameters during gradient updates.
Examples
using GraphNeuralNetworks
g = GNNGraph([1,2,3], [2,3,1])
x = rand(Float32, 2, 3)
model = SAGEConv(2 => 3)
wg = WithGraph(model, g)
# No need to feed the graph to `wg`
@assert wg(x) == model(g, x)
g2 = GNNGraph([1,1,2,3], [2,4,1,1])
x2 = rand(Float32, 2, 4)
# WithGraph will ignore the internal graph if fed with a new one.
@assert wg(g2, x2) == model(g2, x2)
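The forward-pass behaviour shown above (with features only, the stored graph is used; with an explicit graph, that graph replaces it) can be mimicked in a few lines of plain Julia. EdgeGraph, GraphBound and scale_by_edges are hypothetical names; the sketch ignores traingraph and parameter handling entirely, and the real WithGraph may be implemented differently.

```julia
# Hypothetical minimal graph type: edge k goes from s[k] to t[k].
struct EdgeGraph
    s::Vector{Int}
    t::Vector{Int}
end

# Hypothetical stand-in for WithGraph (no traingraph handling).
struct GraphBound{M}
    model::M
    g::EdgeGraph
end

# Feature arrays only: forward the stored graph to the model.
(wg::GraphBound)(x...; kws...) = wg.model(wg.g, x...; kws...)
# A graph passed as the first argument takes precedence over the stored one.
(wg::GraphBound)(g::EdgeGraph, x...; kws...) = wg.model(g, x...; kws...)

# Toy model: scale the features by the number of edges in the graph it receives.
scale_by_edges(g::EdgeGraph, x) = x .* length(g.s)

g = EdgeGraph([1, 2, 3], [2, 3, 1])
wg = GraphBound(scale_by_edges, g)
x = [1.0 2.0 3.0]
@assert wg(x) == scale_by_edges(g, x)

g2 = EdgeGraph([1, 1, 2, 3], [2, 4, 1, 1])
x2 = ones(1, 4)
@assert wg(g2, x2) == scale_by_edges(g2, x2)
```

The second method is more specific in its first argument, so Julia's dispatch selects it whenever a graph is supplied, without any runtime type checks.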