Message Passing
Interface
GNNlib.apply_edges — Function
apply_edges(fmsg, g; [xi, xj, e])
apply_edges(fmsg, g, xi, xj, e=nothing)
Returns the message from node j to node i, computed by applying the message function fmsg to the edges of graph g. In the message-passing scheme, the incoming messages from the neighborhood of i will later be aggregated in order to update the features of node i (see aggregate_neighbors).
The function fmsg operates on batches of edges, therefore xi, xj, and e are tensors whose last dimension is the batch size, or can be named tuples of such tensors.
Arguments
g: An AbstractGNNGraph.
xi: An array or a named tuple containing arrays whose last dimension's size is g.num_nodes. It will be appropriately materialized on the target node of each edge (see also edge_index).
xj: As xi, but materialized on each edge's source node.
e: An array or a named tuple containing arrays whose last dimension's size is g.num_edges.
fmsg: A function that takes as inputs the edge-materialized xi, xj, and e. These are arrays (or named tuples of arrays) whose last dimension's size is the size of a batch of edges. The output of fmsg has to be an array (or a named tuple of arrays) with the same batch size. If layer is also passed to propagate, the signature of fmsg has to be fmsg(layer, xi, xj, e) instead of fmsg(xi, xj, e).
See also propagate and aggregate_neighbors.
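The "materialization" of node features on edges described above can be sketched without the library. This is only an illustration under assumed semantics (a plain gather along the last dimension); the edge vectors s and t, the feature matrix x, and the message function fmsg are made-up example values, not part of the documented API.

```julia
# Library-free sketch of what apply_edges is described to compute.
# Edge k goes from source node s[k] to target node t[k] (example data).
s = [1, 2, 3, 1]           # source node j of each edge
t = [2, 3, 1, 3]           # target node i of each edge
x = Float32[1  2  3;       # node features, size (2, num_nodes)
            10 20 30]

# Materialize node features on edges by gathering along the last dimension.
xi = x[:, t]               # target-node features, size (2, num_edges)
xj = x[:, s]               # source-node features, size (2, num_edges)

# Hypothetical message function: source features minus target features.
fmsg(xi, xj, e) = xj .- xi

m = fmsg(xi, xj, nothing)  # one message column per edge
```

With the library itself, the corresponding call would follow the documented keyword form, apply_edges(fmsg, g; xi=x, xj=x), where g holds the same edges.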
GNNlib.aggregate_neighbors — Function
aggregate_neighbors(g, aggr, m)
Given a graph g, edge features m, and an aggregation operator aggr (e.g. +, min, max, mean), returns the new node features
\[\mathbf{x}_i = \square_{j \in \mathcal{N}(i)} \mathbf{m}_{j\to i}\]
Neighborhood aggregation is the second step of propagate, where it comes after apply_edges.
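For aggr = +, the formula above amounts to summing each edge message into the column of its target node. The following is a minimal library-free sketch under that assumption; aggregate_sum, t, m, and num_nodes are hypothetical names and example values, not the library's implementation.

```julia
# Library-free sketch of neighborhood aggregation with aggr = +.
t = [2, 3, 1, 3]                 # target node i of each edge (example data)
m = Float32[1  2  3  4;          # one message column per edge
            10 20 30 40]
num_nodes = 3

# Sum every message into the column of its target node.
function aggregate_sum(m, t, num_nodes)
    xnew = zeros(eltype(m), size(m, 1), num_nodes)
    for (k, i) in enumerate(t)
        xnew[:, i] .+= m[:, k]
    end
    return xnew
end

xnew = aggregate_sum(m, t, num_nodes)  # size (2, num_nodes)
```

Node 3 receives two messages (edges 2 and 4), so its column is their sum; nodes 1 and 2 each receive a single message.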
GNNlib.propagate — Function
propagate(fmsg, g, aggr; [xi, xj, e])
propagate(fmsg, g, aggr, xi, xj, e=nothing)
Performs message passing on graph g. Takes care of materializing the node features on each edge, applying the message function fmsg, and returning an aggregated message $\bar{\mathbf{m}}$ (depending on the return value of fmsg, an array or a named tuple of arrays with last dimension's size g.num_nodes).
It can be decomposed in two steps:
m = apply_edges(fmsg, g, xi, xj, e)
m̄ = aggregate_neighbors(g, aggr, m)
GNN layers typically call propagate in their forward pass, passing a closure as the fmsg argument.
Arguments
g: A GNNGraph.
xi: An array or a named tuple containing arrays whose last dimension's size is g.num_nodes. It will be appropriately materialized on the target node of each edge (see also edge_index).
xj: As xi, but materialized on each edge's source node.
e: An array or a named tuple containing arrays whose last dimension's size is g.num_edges.
fmsg: A generic function that will be passed on to apply_edges. It has to take as inputs the edge-materialized xi, xj, and e (arrays or named tuples of arrays whose last dimension's size is the size of a batch of edges). Its output has to be an array or a named tuple of arrays with the same batch size. If layer is also passed to propagate, the signature of fmsg has to be fmsg(layer, xi, xj, e) instead of fmsg(xi, xj, e).
aggr: Neighborhood aggregation operator. Use +, mean, max, or min.
Examples
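using GraphNeuralNetworks, Flux

struct GNNConv <: GNNLayer
    W   # weight matrix, size (out, in)
    b   # bias vector, length out
    σ   # activation function
end

Flux.@layer GNNConv

function GNNConv(ch::Pair{Int,Int}, σ=identity)
    in, out = ch
    W = Flux.glorot_uniform(out, in)
    b = zeros(Float32, out)
    GNNConv(W, b, σ)
end

function (l::GNNConv)(g::GNNGraph, x::AbstractMatrix)
    # Each message is the linearly transformed feature of the source node.
    message(xi, xj, e) = l.W * xj
    # Sum the incoming messages on every target node.
    m̄ = propagate(message, g, +, xj=x)
    # The bias field is named `b` in the struct above.
    return l.σ.(m̄ .+ l.b)
end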
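
# Example inputs (not in the original): a random graph with 10 nodes and 30 edges,
# and 10 input features per node.
g = rand_graph(10, 30)
x = rand(Float32, 10, g.num_nodes)

l = GNNConv(10 => 20)
l(g, x)
See also apply_edges and aggregate_neighbors.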
Built-in message functions
GNNlib.copy_xi — Function
copy_xi(xi, xj, e) = xi
GNNlib.copy_xj — Function
copy_xj(xi, xj, e) = xj
GNNlib.xi_dot_xj — Function
xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1)
GNNlib.xi_sub_xj — Function
xi_sub_xj(xi, xj, e) = xi .- xj
GNNlib.xj_sub_xi — Function
xj_sub_xi(xi, xj, e) = xj .- xi
GNNlib.e_mul_xj — Function
e_mul_xj(xi, xj, e) = reshape(e, (...)) .* xj
Reshape e into a broadcast-compatible shape with xj (by prepending singleton dimensions), then perform broadcasted multiplication.
GNNlib.w_mul_xj — Function
w_mul_xj(xi, xj, w) = reshape(w, (...)) .* xj
Similar to e_mul_xj but specialized on scalar edge features (weights).
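To make the shapes concrete, the sketch below re-creates two of the definitions listed above as local stand-ins. The my_ prefix marks them as hypothetical copies for illustration rather than the library's code, and the reshape to (1, num_edges) assumes that xj is a matrix.

```julia
# Local stand-ins mirroring the listed definitions (hypothetical names).
my_xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1)

# For a matrix xj, prepend one singleton dimension to the weight vector
# so that each edge's column is scaled by its own weight.
my_w_mul_xj(xi, xj, w) = reshape(w, 1, :) .* xj

xi = Float32[1 2; 3 4]   # edge-materialized target features, (2, num_edges)
xj = Float32[5 6; 7 8]   # edge-materialized source features, (2, num_edges)
w  = Float32[2, 10]      # one scalar weight per edge

d  = my_xi_dot_xj(xi, xj, nothing)  # per-edge dot product, size (1, 2)
wm = my_w_mul_xj(xi, xj, w)         # per-edge scaled xj, size (2, 2)
```

Both results keep the number of edges as their last dimension, which is what the aggregation step expects.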