Message Passing
Index
GNNlib.aggregate_neighbors
GNNlib.apply_edges
GNNlib.copy_xi
GNNlib.copy_xj
GNNlib.e_mul_xj
GNNlib.propagate
GNNlib.w_mul_xj
GNNlib.xi_dot_xj
GNNlib.xi_sub_xj
GNNlib.xj_sub_xi
Interface
GNNlib.apply_edges
apply_edges(fmsg, g; [xi, xj, e])
apply_edges(fmsg, g, xi, xj, e=nothing)
Returns the message from node j to node i, computed by applying the message function fmsg on the edges of graph g. In the message-passing scheme, the incoming messages from the neighborhood of i are later aggregated in order to update the features of node i (see aggregate_neighbors).
The function fmsg operates on batches of edges. Therefore xi, xj, and e are tensors whose last dimension is the batch size, or named tuples of such tensors.
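To make the edge-batching semantics concrete, here is a minimal plain-Julia sketch of what apply_edges computes. It does not use GNNlib. The graph is represented only by two vectors s (source of each edge) and t (target of each edge), and the helper name apply_edges_sketch is hypothetical. Only matrix features are handled (no named tuples). Node features are gathered onto edges by column indexing, and fmsg is called once on the whole batch of edges.

```julia
# Hypothetical helper (not part of GNNlib): s[k] and t[k] are the source and
# target node of edge k. Only matrix features (features × nodes) are handled.
function apply_edges_sketch(fmsg, s::AbstractVector{<:Integer}, t::AbstractVector{<:Integer};
                            xi = nothing, xj = nothing, e = nothing)
    # Materialize xi on each edge's target node and xj on each edge's source node.
    # After indexing, the last dimension is the batch of edges.
    xi_e = xi === nothing ? nothing : xi[:, t]
    xj_e = xj === nothing ? nothing : xj[:, s]
    # fmsg sees all edges at once, not one edge at a time.
    return fmsg(xi_e, xj_e, e)
end

# Three nodes with 2 features each (one column per node).
x = Float32[1 2 3;
            10 20 30]
s = [1, 2, 3]   # edge 1: 1 → 2, edge 2: 2 → 3, edge 3: 3 → 1
t = [2, 3, 1]

# Message = source features minus target features, one column per edge.
m = apply_edges_sketch((xi, xj, e) -> xj .- xi, s, t; xi = x, xj = x)
# Float32[-1 -1 2; -10 -10 20]
```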
Arguments
g: An AbstractGNNGraph.
xi: An array or a named tuple containing arrays whose last dimension's size is g.num_nodes. It will be appropriately materialized on the target node of each edge (see also edge_index).
xj: As xi, but now to be materialized on each edge's source node.
e: An array or a named tuple containing arrays whose last dimension's size is g.num_edges.
fmsg: A function that takes as inputs the edge-materialized xi, xj, and e. These are arrays (or named tuples of arrays) whose last dimension's size is the size of a batch of edges. The output of fmsg has to be an array (or a named tuple of arrays) with the same batch size. If layer is also passed to propagate, the signature of fmsg has to be fmsg(layer, xi, xj, e) instead of fmsg(xi, xj, e).
See also propagate and aggregate_neighbors.
GNNlib.aggregate_neighbors
aggregate_neighbors(g, aggr, m)
Given a graph g, edge features (messages) m, and an aggregation operator aggr (e.g. +, min, max, mean), returns the new node features
\[\mathbf{x}_i = \square_{j \in \mathcal{N}(i)} \mathbf{m}_{j\to i}\]
where $\square$ denotes the operator aggr applied over the neighborhood $\mathcal{N}(i)$ of node i.
Neighborhood aggregation is the second step of propagate, where it comes after apply_edges.
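The aggregation formula can be illustrated with a plain-Julia sketch that does not depend on GNNlib. The helper aggregate_neighbors_sketch is hypothetical: it takes the messages m (features × edges) together with a vector t holding the target node of each edge, and reduces the columns that arrive at each node. It loops over nodes for readability; an efficient implementation would instead use a single scatter pass. Giving zero features to nodes without incoming edges is an assumption of this sketch, not a statement about GNNlib's behavior.

```julia
using Statistics: mean

# Hypothetical helper (not part of GNNlib): reduce the messages m
# (features × edges) arriving at each node, where t[k] is the target of edge k.
function aggregate_neighbors_sketch(aggr, m::AbstractMatrix, t::AbstractVector{<:Integer}, num_nodes::Integer)
    out = zeros(eltype(m), size(m, 1), num_nodes)
    for i in 1:num_nodes
        incoming = findall(==(i), t)   # indices of the edges j → i
        isempty(incoming) && continue  # assumption of this sketch: isolated nodes stay zero
        cols = m[:, incoming]
        if aggr === mean
            out[:, i] = vec(mean(cols; dims = 2))
        else
            # Works for binary operators such as +, max, and min.
            out[:, i] = vec(reduce(aggr, cols; dims = 2))
        end
    end
    return out
end

m = Float32[1 2 4;
            10 20 40]   # 2 features, 3 edges
t = [2, 2, 1]           # edges 1 and 2 point to node 2, edge 3 to node 1; node 3 has none

aggregate_neighbors_sketch(+, m, t, 3)     # Float32[4 3 0; 40 30 0]
aggregate_neighbors_sketch(max, m, t, 3)   # Float32[4 2 0; 40 20 0]
aggregate_neighbors_sketch(mean, m, t, 3)  # Float32[4 1.5 0; 40 15 0]
```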
GNNlib.propagate
propagate(fmsg, g, aggr; [xi, xj, e])
propagate(fmsg, g, aggr, xi, xj, e=nothing)
Performs message passing on graph g. Takes care of materializing the node features on each edge, applying the message function fmsg, and returning an aggregated message $\bar{\mathbf{m}}$ (depending on the return value of fmsg, an array or a named tuple of arrays with last dimension's size g.num_nodes).
It can be decomposed into two steps:
m = apply_edges(fmsg, g, xi, xj, e)
m̄ = aggregate_neighbors(g, aggr, m)
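The two-step decomposition can be sketched end to end in plain Julia, without GNNlib. The helper propagate_sum_sketch is hypothetical and restricted to matrix features and + aggregation; s and t are assumed vectors holding the source and target of each edge, and mbar plays the role of m̄. With a message that simply copies xj and a feature equal to 1 on every node, the result is each node's number of incoming edges.

```julia
# Hypothetical helper (not part of GNNlib): matrix features and + aggregation only.
# s[k] and t[k] are the source and target node of edge k.
function propagate_sum_sketch(fmsg, s, t, num_nodes; xi, xj, e = nothing)
    # Step 1, as in apply_edges: materialize node features on edges, compute messages.
    m = fmsg(xi[:, t], xj[:, s], e)
    # Step 2, as in aggregate_neighbors with +: add each message into its target node.
    mbar = zeros(eltype(m), size(m, 1), num_nodes)
    for (k, i) in enumerate(t)
        mbar[:, i] .+= m[:, k]
    end
    return mbar
end

x = ones(Float32, 1, 3)   # a single feature equal to 1 on every node
s = [1, 3, 2]
t = [2, 2, 3]             # edges 1 → 2, 3 → 2, 2 → 3
mbar = propagate_sum_sketch((xi, xj, e) -> xj, s, t, 3; xi = x, xj = x)
# Float32[0 2 1]: the number of incoming edges of each node
```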
GNN layers typically call propagate in their forward pass, providing a closure as the message function fmsg.
Arguments
g: A GNNGraph.
xi: An array or a named tuple containing arrays whose last dimension's size is g.num_nodes. It will be appropriately materialized on the target node of each edge (see also edge_index).
xj: As xi, but to be materialized on each edge's source node.
e: An array or a named tuple containing arrays whose last dimension's size is g.num_edges.
fmsg: A generic function that will be passed over to apply_edges. It has to take as inputs the edge-materialized xi, xj, and e (arrays or named tuples of arrays whose last dimension's size is the size of a batch of edges). Its output has to be an array or a named tuple of arrays with the same batch size. If layer is also passed to propagate, the signature of fmsg has to be fmsg(layer, xi, xj, e) instead of fmsg(xi, xj, e).
aggr: Neighborhood aggregation operator. Use +, mean, max, or min.
Examples
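using GraphNeuralNetworks, Flux

struct GNNConv <: GNNLayer
    W
    b
    σ
end

Flux.@layer GNNConv

function GNNConv(ch::Pair{Int,Int}, σ=identity)
    in_dim, out_dim = ch
    W = Flux.glorot_uniform(out_dim, in_dim)
    b = zeros(Float32, out_dim)
    GNNConv(W, b, σ)
end

function (l::GNNConv)(g::GNNGraph, x::AbstractMatrix)
    # Message from source j to target i: a linear transform of xj.
    message(xi, xj, e) = l.W * xj
    # Sum the messages arriving at each node.
    m̄ = propagate(message, g, +, xj=x)
    return l.σ.(m̄ .+ l.b)
end

# Example input: a random graph with 10 nodes and 30 edges, 10 features per node.
g = rand_graph(10, 30)
x = randn(Float32, 10, g.num_nodes)
l = GNNConv(10 => 20)
l(g, x)    # 20 × 10 matrix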
See also apply_edges and aggregate_neighbors.
Built-in message functions
GNNlib.copy_xi
copy_xi(xi, xj, e) = xi
GNNlib.copy_xj
copy_xj(xi, xj, e) = xj
GNNlib.xi_dot_xj
xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1)
GNNlib.xi_sub_xj
xi_sub_xj(xi, xj, e) = xi .- xj
GNNlib.xj_sub_xi
xj_sub_xi(xi, xj, e) = xj .- xi
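The one-line definitions above can be checked on a small batch of edges. The sketch below defines local copies of them (so GNNlib is not needed) and applies them to made-up edge-materialized features with 2 features and 2 edges. Note that xi_dot_xj keeps a leading singleton dimension, so its output still has the edges as the last dimension, as required of fmsg outputs.

```julia
# Local copies of the one-line definitions listed above (no GNNlib needed).
copy_xi(xi, xj, e) = xi
copy_xj(xi, xj, e) = xj
xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims = 1)
xi_sub_xj(xi, xj, e) = xi .- xj
xj_sub_xi(xi, xj, e) = xj .- xi

# Edge-materialized features for a batch of 2 edges (columns), 2 features each.
xi = [1.0 0.0;
      2.0 1.0]
xj = [3.0 4.0;
      1.0 2.0]

copy_xj(xi, xj, nothing)     # the source features themselves
xi_dot_xj(xi, xj, nothing)   # 1 × 2 matrix [5.0 2.0]: one score per edge
xi_sub_xj(xi, xj, nothing)   # [-2.0 -4.0; 1.0 -1.0]
```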
GNNlib.e_mul_xj
e_mul_xj(xi, xj, e) = reshape(e, (...)) .* xj
Reshapes e into a shape that is broadcast-compatible with xj (by prepending singleton dimensions), then performs a broadcasted multiplication.
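The reshape step can be sketched in plain Julia. The helper e_mul_xj_sketch is hypothetical: it reconstructs "prepend singleton dimensions" from the description above and is not necessarily GNNlib's exact code (the library's reshape arguments are elided here). It pads size(e) with leading 1s until it has as many dimensions as xj, so that each entry of e scales the matching slice of xj.

```julia
# Hypothetical sketch (not necessarily GNNlib's exact code): reshape e by
# prepending singleton dimensions, then broadcast-multiply with xj.
function e_mul_xj_sketch(xi, xj::AbstractArray, e::AbstractArray)
    # Leading 1s needed so that ndims of the reshaped e equals ndims(xj).
    ones_prefix = ntuple(_ -> 1, ndims(xj) - ndims(e))
    return reshape(e, (ones_prefix..., size(e)...)) .* xj
end

xj = ones(3, 2, 4)          # a 3-D xj whose last dimension is 4 edges
e = [1.0 2.0 3.0 4.0;
     5.0 6.0 7.0 8.0]       # size (2, 4), reshaped to (1, 2, 4)
y = e_mul_xj_sketch(nothing, xj, e)
# size(y) == (3, 2, 4), and y[:, h, k] is filled with e[h, k]
```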
GNNlib.w_mul_xj
w_mul_xj(xi, xj, w) = reshape(w, (...)) .* xj
Similar to e_mul_xj, but specialized on scalar edge features (weights).
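With one scalar weight per edge, the reshape reduces to turning the weight vector into a 1 × E row, which then scales every feature of the corresponding edge column. The sketch below is a hypothetical plain-Julia helper, w_mul_xj_sketch, restricted to matrix xj (features × edges); it is not GNNlib's code.

```julia
# Hypothetical sketch (not GNNlib's code): w holds one scalar weight per edge,
# xj is features × edges; the row-shaped w scales each edge column.
w_mul_xj_sketch(xi, xj::AbstractMatrix, w::AbstractVector) = reshape(w, 1, length(w)) .* xj

xj = [1.0 1.0 1.0;
      2.0 2.0 2.0]    # 2 features, 3 edges
w = [0.5, 1.0, 2.0]   # edge weights
y = w_mul_xj_sketch(nothing, xj, w)
# [0.5 1.0 2.0; 1.0 2.0 4.0]
```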