Message Passing


Interface

GNNlib.apply_edges (Function)
apply_edges(fmsg, g; [xi, xj, e])
apply_edges(fmsg, g, xi, xj, e=nothing)

Returns the messages computed by applying the message function fmsg on the edges of graph g; for an edge from node j to node i, this is the message from j to i. In the message-passing scheme, the incoming messages from the neighborhood of i are later aggregated to update the features of node i (see aggregate_neighbors).

The function fmsg operates on batches of edges, therefore xi, xj, and e are tensors whose last dimension is the batch size, or can be named tuples of such tensors.

Arguments

  • g: An AbstractGNNGraph.
  • xi: An array or a named tuple containing arrays whose last dimension's size is g.num_nodes. It is materialized on each edge using the features of the edge's target node (see also edge_index).
  • xj: As xi, but materialized using the features of each edge's source node.
  • e: An array or a named tuple containing arrays whose last dimension's size is g.num_edges.
  • fmsg: A function that takes as inputs the edge-materialized xi, xj, and e. These are arrays (or named tuples of arrays) whose last dimension's size is the size of a batch of edges. The output of fmsg must be an array (or a named tuple of arrays) with the same batch size. If a layer is also passed to propagate, fmsg must have the signature fmsg(layer, xi, xj, e) instead of fmsg(xi, xj, e).

See also propagate and aggregate_neighbors.
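The gather-then-apply step can be sketched in plain Julia without GNNlib. This is a hypothetical illustration, not GNNlib's implementation: it assumes the graph is given as two vectors s and t (edge k goes from s[k] to t[k]) and handles plain matrices only, not named tuples. The name apply_edges_sketch is made up for this example.

```julia
# Hypothetical sketch of the gather-then-apply step behind apply_edges.
# Assumption: edge k goes from source node s[k] to target node t[k].
# Only plain matrices are handled here, not named tuples.
function apply_edges_sketch(fmsg, s::AbstractVector{Int}, t::AbstractVector{Int},
                            xi, xj, e=nothing)
    # Materialize node features on edges: one column per edge.
    xi_e = xi === nothing ? nothing : xi[:, t]  # features of each edge's target i
    xj_e = xj === nothing ? nothing : xj[:, s]  # features of each edge's source j
    # fmsg sees the whole batch of edges at once; e is already per-edge.
    return fmsg(xi_e, xj_e, e)
end

s = [1, 2, 3]                     # sources
t = [2, 3, 1]                     # targets
x = Float32[1 2 3; 10 20 30]      # 2 features × 3 nodes
m = apply_edges_sketch((xi, xj, e) -> xi .- xj, s, t, x, x)
# column k of m is x[:, t[k]] - x[:, s[k]]
```

Because fmsg receives whole batches, it is written with array operations (here a broadcasted subtraction) rather than as a per-edge loop.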

GNNlib.aggregate_neighbors (Function)
aggregate_neighbors(g, aggr, m)

Given a graph g, edge features m, and an aggregation operator aggr (e.g. +, min, max, mean), returns the new node features, where □ stands for aggr applied over the neighborhood 𝒩(i) of node i:

\[\mathbf{x}_i = \square_{j \in \mathcal{N}(i)} \mathbf{m}_{j\to i}\]

Neighborhood aggregation is the second step of propagate, following apply_edges.
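The reduction in the formula above can be sketched in plain Julia. This is a hypothetical illustration, not GNNlib's implementation: it assumes t[k] is the target node of edge k and that the reduction is passed as sum, maximum, minimum, or mean (functions accepting a dims keyword) in place of the operators +, max, min, mean. The name aggregate_neighbors_sketch is made up, and giving nodes without incoming edges a zero result is a choice of this sketch.

```julia
using Statistics: mean

# Hypothetical sketch of neighborhood aggregation (not GNNlib's implementation).
# Assumptions: t[k] is the target node of edge k, m has one column per edge,
# and `reducer` accepts a `dims` keyword (sum, maximum, minimum, mean).
function aggregate_neighbors_sketch(reducer, t::AbstractVector{Int},
                                    m::AbstractMatrix, num_nodes::Int)
    # Nodes without incoming edges keep a zero column in this sketch.
    out = zeros(eltype(m), size(m, 1), num_nodes)
    for i in 1:num_nodes
        incoming = findall(==(i), t)               # indices of edges j → i
        if !isempty(incoming)
            # Reduce the messages m_{j→i} along the edge dimension.
            out[:, i] = vec(reducer(m[:, incoming]; dims=2))
        end
    end
    return out
end

t = [2, 3, 1, 2]                 # targets of 4 edges
m = Float32[1 2 3 4]             # one scalar message per edge
x_sum  = aggregate_neighbors_sketch(sum, t, m, 3)   # node 2 receives 1 + 4
x_mean = aggregate_neighbors_sketch(mean, t, m, 3)  # node 2 receives (1 + 4) / 2
```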

GNNlib.propagate (Function)
propagate(fmsg, g, aggr; [xi, xj, e])
propagate(fmsg, g, aggr, xi, xj, e=nothing)

Performs message passing on graph g. Takes care of materializing the node features on each edge, applying the message function fmsg, and returning an aggregated message $\bar{\mathbf{m}}$ (depending on the return value of fmsg, an array or a named tuple of arrays with last dimension's size g.num_nodes).

It can be decomposed into two steps:

m = apply_edges(fmsg, g, xi, xj, e)
m̄ = aggregate_neighbors(g, aggr, m)

GNN layers typically call propagate in their forward pass, passing a closure as the message function fmsg.
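The two-step decomposition can be traced end to end with a plain-Julia sketch. It is hypothetical and not GNNlib's code: it assumes edge k goes from s[k] to t[k], uses reductions such as sum or mean in place of the operators +, max, min, mean, and the names propagate_sketch and message are made up for illustration. The message function captures a weight matrix W, the way a layer's closure would capture its parameters.

```julia
using Statistics: mean

# Hypothetical end-to-end sketch of propagate in plain Julia (not GNNlib's code).
# Assumptions: edge k goes from s[k] to t[k]; `reducer` is sum, maximum,
# minimum, or mean (used in place of +, max, min, mean).
function propagate_sketch(fmsg, s, t, reducer, x::AbstractMatrix, num_nodes::Int)
    # Step 1 (apply_edges): gather node features on edges, compute messages.
    m = fmsg(x[:, t], x[:, s], nothing)
    # Step 2 (aggregate_neighbors): reduce the messages arriving at each node.
    out = zeros(eltype(m), size(m, 1), num_nodes)
    for i in 1:num_nodes
        incoming = findall(==(i), t)
        if !isempty(incoming)
            out[:, i] = vec(reducer(m[:, incoming]; dims=2))
        end
    end
    return out
end

# The message is a linear map of the source features; W is captured
# from the enclosing scope, as a layer's weights would be.
W = Float32[1 0; 0 2]
message(xi, xj, e) = W * xj

s = [1, 2, 3, 1]                 # sources
t = [2, 3, 1, 3]                 # targets
x = Float32[1 2 3; 1 1 1]        # 2 features × 3 nodes
mbar = propagate_sketch(message, s, t, sum, x, 3)
```

Node 3 has two incoming edges (from nodes 2 and 1), so its column is the sum of two messages, while nodes 1 and 2 each receive a single message.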

Arguments

  • g: A GNNGraph.
  • xi: An array or a named tuple containing arrays whose last dimension's size is g.num_nodes. It is materialized on each edge using the features of the edge's target node (see also edge_index).
  • xj: As xi, but materialized using the features of each edge's source node.
  • e: An array or a named tuple containing arrays whose last dimension's size is g.num_edges.
  • fmsg: A generic function that is passed on to apply_edges. It must take as inputs the edge-materialized xi, xj, and e (arrays or named tuples of arrays whose last dimension's size is the size of a batch of edges). Its output must be an array or a named tuple of arrays with the same batch size. If a layer is also passed to propagate, fmsg must have the signature fmsg(layer, xi, xj, e) instead of fmsg(xi, xj, e).
  • aggr: Neighborhood aggregation operator. Use +, mean, max, or min.

Examples

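using GraphNeuralNetworks, Flux

struct GNNConv <: GNNLayer
    W
    b
    σ
end

Flux.@layer GNNConv

function GNNConv(ch::Pair{Int,Int}, σ=identity)
    nin, nout = ch
    W = Flux.glorot_uniform(nout, nin)
    b = zeros(Float32, nout)
    GNNConv(W, b, σ)
end

function (l::GNNConv)(g::GNNGraph, x::AbstractMatrix)
    # Message from source node j to target node i: a linear map of xj.
    message(xi, xj, e) = l.W * xj
    # Sum the incoming messages at each node.
    m̄ = propagate(message, g, +, xj=x)
    return l.σ.(m̄ .+ l.b)
end

# Example input: a random graph with 10 nodes and 30 edges,
# and 10 input features per node.
g = rand_graph(10, 30)
x = randn(Float32, 10, g.num_nodes)

l = GNNConv(10 => 20)
y = l(g, x)   # 20 × 10 matrix: 20 output features per node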

See also apply_edges and aggregate_neighbors.


Built-in message functions

GNNlib.e_mul_xj (Function)
e_mul_xj(xi, xj, e) = reshape(e, (...)) .* xj

Reshapes e into a shape broadcast-compatible with xj (by prepending singleton dimensions), then performs broadcasted multiplication.
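The reshape-then-broadcast idea can be sketched in plain Julia. This is a hypothetical re-implementation for illustration only; GNNlib's actual code may differ, and the name e_mul_xj_sketch is made up. It assumes e has no more dimensions than xj and that both share the edge dimension as their last one.

```julia
# Hypothetical sketch of the reshape-and-broadcast idea behind e_mul_xj
# (illustration only; GNNlib's actual code may differ).
function e_mul_xj_sketch(xi, xj::AbstractArray, e::AbstractArray)
    # Prepend singleton dimensions so e has as many dimensions as xj.
    # The edge dimension stays last in both arrays, so broadcasting
    # multiplies each edge's features in e with that edge's slice of xj.
    extra = ndims(xj) - ndims(e)
    e_b = reshape(e, (ntuple(_ -> 1, extra)..., size(e)...))
    return e_b .* xj
end

xj = Float32[1 2 3; 4 5 6]       # 2 features × 3 edges (source-node features)
e  = Float32[10, 100, 1000]      # one scalar feature per edge
y  = e_mul_xj_sketch(nothing, xj, e)   # column k of xj scaled by e[k]
```

The same reshape also covers higher-dimensional inputs: a 4 × 3 matrix e and a 2 × 4 × 3 array xj become broadcast-compatible once e is reshaped to 1 × 4 × 3.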

GNNlib.w_mul_xj (Function)
w_mul_xj(xi, xj, w) = reshape(w, (...)) .* xj

Similar to e_mul_xj, but specialized for scalar edge features (weights).
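A typical use of per-edge weights is a weighted neighbor sum. The plain-Julia sketch below is hypothetical and not GNNlib's code: it assumes w is a vector with one weight per edge, edge k goes from s[k] to t[k], and w_mul_xj_sketch is a made-up name. It scales each source feature column by its edge weight and then sums the results at the target nodes.

```julia
# Hypothetical sketch (not GNNlib's code): scale each edge's source features
# by the scalar weight w[k], reshaped to broadcast along the edge dimension.
w_mul_xj_sketch(xi, xj::AbstractArray, w::AbstractVector) =
    reshape(w, (ntuple(_ -> 1, ndims(xj) - 1)..., length(w))) .* xj

s = [1, 2]                        # sources
t = [3, 3]                        # both edges point to node 3
x = Float32[1 2 0; 3 4 0]         # 2 features × 3 nodes
w = Float32[0.5, 2]               # one weight per edge

m = w_mul_xj_sketch(nothing, x[:, s], w)   # per-edge weighted messages
out = zeros(Float32, size(m, 1), 3)
for (k, i) in enumerate(t)
    out[:, i] .+= m[:, k]         # sum aggregation at target node i
end
```

Node 3 ends up with 0.5 times the features of node 1 plus 2 times those of node 2; nodes 1 and 2 receive no messages and stay zero.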
