Pooling Layers

GNNLux.GlobalAttentionPool (Type)
GlobalAttentionPool(fgate, ffeat=identity)

Global soft-attention pooling layer from the Gated Graph Sequence Neural Networks paper.

\[\mathbf{u}_V = \sum_{i\in V} \alpha_i\, f_{feat}(\mathbf{x}_i)\]

where the coefficients $\alpha_i$ are given by a GNNlib.softmax_nodes operation, i.e. a softmax taken over the nodes of each graph:

\[\alpha_i = \frac{e^{f_{gate}(\mathbf{x}_i)}} {\sum_{i'\in V} e^{f_{gate}(\mathbf{x}_{i'})}}.\]
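The two formulas above can be traced by hand. The following plain-Julia sketch (Base only, no GNNLux) computes $\mathbf{u}_V$ for every graph in a batch, applying the softmax separately within each graph, which is what a per-graph softmax over nodes amounts to. The function name attention_pool, the graph_indicator vector, and the linear gate and feat maps are hypothetical stand-ins for the layer's internals and for trained $f_{gate}$ and $f_{feat}$ networks.

```julia
# Plain-Julia sketch of global soft-attention pooling (no GNNLux involved).
# `gate` and `feat` are hypothetical stand-ins for trained f_gate / f_feat networks.
function attention_pool(x::AbstractMatrix, graph_indicator::AbstractVector{<:Integer},
                        gate, feat, num_graphs::Integer)
    scores = [gate(x[:, i]) for i in axes(x, 2)]            # f_gate(x_i): one scalar per node
    h = reduce(hcat, [feat(x[:, i]) for i in axes(x, 2)])   # f_feat(x_i) stored as columns
    u = zeros(float(eltype(h)), size(h, 1), num_graphs)
    for k in 1:num_graphs
        nodes = findall(==(k), graph_indicator)             # nodes belonging to graph k
        s = scores[nodes]
        e = exp.(s .- maximum(s))                           # shift by the max for numerical stability
        α = e ./ sum(e)                                     # softmax over the nodes of graph k only
        u[:, k] = h[:, nodes] * α                           # u_V = Σ_i α_i f_feat(x_i)
    end
    return u
end

# Hypothetical linear gate and feature maps (D_in = 2, D_out = 3).
w = [0.5, -1.0]
W = [1.0 0.0;
     0.0 2.0;
     1.0 1.0]
gate(v) = sum(w .* v)
feat(v) = W * v

x = [1.0 2.0 0.0 1.0;
     0.0 1.0 1.0 3.0]      # 4 nodes, 2 features each (one column per node)
gi = [1, 1, 2, 2]          # nodes 1 and 2 form graph 1, nodes 3 and 4 form graph 2
u = attention_pool(x, gi, gate, feat, 2)   # 3×2 matrix, one column per graph
```

With a constant gate every $\alpha_i$ within a graph is equal, so the result collapses to mean pooling of $f_{feat}(\mathbf{x}_i)$, which makes a handy sanity check.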

Arguments

  • fgate: The function $f_{gate}: \mathbb{R}^{D_{in}} \to \mathbb{R}$. It is typically expressed by a neural network.

  • ffeat: The function $f_{feat}: \mathbb{R}^{D_{in}} \to \mathbb{R}^{D_{out}}$. It is typically expressed by a neural network.

Examples

using Graphs, LuxCore, Lux, GNNLux, Random

rng = Random.default_rng()
chin = 6
chout = 5

fgate = Dense(chin, 1)
ffeat = Dense(chin, chout)
pool = GlobalAttentionPool(fgate, ffeat)

g = batch([GNNGraph(Graphs.random_regular_graph(10, 4),
                    ndata = rand(Float32, chin, 10))
           for i in 1:3])

ps = (fgate = LuxCore.initialparameters(rng, fgate), ffeat = LuxCore.initialparameters(rng, ffeat))
st = (fgate = LuxCore.initialstates(rng, fgate), ffeat = LuxCore.initialstates(rng, ffeat))

u, st = pool(g, g.ndata.x, ps, st)

@assert size(u) == (chout, g.num_graphs)
GNNLux.GlobalPool (Type)
GlobalPool(aggr)

Global pooling layer for graph neural networks. Takes a graph and its node features as inputs and performs the operation

\[\mathbf{u}_V = \square_{i \in V} \mathbf{x}_i\]

where $V$ is the set of nodes of the input graph and the type of aggregation represented by $\square$ is selected by the aggr argument. Commonly used aggregations are mean, max, and +.

See also GNNlib.reduce_nodes.
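The $\square$ reduction is easy to spell out for a batch of graphs. This is a minimal Base-Julia sketch, assuming node features are stored column-wise and that a graph_indicator vector maps each node to its graph (as a batched GNNGraph does). The helpers pool_by_graph, sum_nodes, max_nodes and mean_nodes are hypothetical names for illustration, not GNNlib functions.

```julia
# Plain-Julia sketch of global pooling over a batch of graphs.
# Node features are columns of x; graph_indicator[i] is the graph of node i.
sum_nodes(m)  = vec(sum(m; dims = 2))
max_nodes(m)  = vec(maximum(m; dims = 2))
mean_nodes(m) = vec(sum(m; dims = 2)) ./ size(m, 2)

function pool_by_graph(aggr, x::AbstractMatrix, graph_indicator::AbstractVector{<:Integer},
                       num_graphs::Integer)
    u = zeros(float(eltype(x)), size(x, 1), num_graphs)
    for k in 1:num_graphs
        u[:, k] = aggr(x[:, findall(==(k), graph_indicator)])   # □ over the nodes of graph k
    end
    return u
end

x = [1.0 2.0 3.0 4.0 5.0;
     0.0 1.0 0.0 1.0 0.0]   # 5 nodes, 2 features each
gi = [1, 1, 2, 2, 2]        # graph 1 has nodes 1-2, graph 2 has nodes 3-5
pool_by_graph(mean_nodes, x, gi, 2)   # columns (1.5, 0.5) and (4.0, 1/3)
```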

Examples

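using Lux, GNNLux, Graphs, MLUtils, Random, Statistics

rng = Random.default_rng()
pool = GlobalPool(mean)
ps, st = Lux.setup(rng, pool)   # GlobalPool has no trainable parameters

# Single graph: one output column
g = GNNGraph(erdos_renyi(10, 4))
X = rand(32, 10)
u, st = pool(g, X, ps, st)      # size(u) == (32, 1)

# Batch of 5 graphs: one output column per graph
g = MLUtils.batch([GNNGraph(erdos_renyi(10, 4)) for _ in 1:5])
X = rand(32, 50)
u, st = pool(g, X, ps, st)      # size(u) == (32, 5)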
GNNLux.TopKPool (Type)
TopKPool(adj, k, in_channel)

Top-k pooling layer, which keeps the k highest-scoring nodes of a graph.

Arguments

  • adj: Adjacency matrix of the graph.
  • k: Number of top-scoring nodes to keep.
  • in_channel: Dimension of the input node features.
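The docstring does not spell out how nodes are scored. A common formulation of top-k pooling is the gPool operation from Graph U-Nets (Gao and Ji, 2019): project each node's features onto a learnable vector p, keep the k highest-scoring nodes, gate their features by a sigmoid of the score, and restrict the adjacency matrix to the kept nodes. The sketch below implements that recipe in plain Julia; whether TopKPool follows it exactly is an assumption, and topk_pool, sigmoid and the fixed p are hypothetical. It takes the adjacency matrix as an input, whereas the layer receives adj at construction.

```julia
# Plain-Julia sketch of gPool-style top-k pooling (Graph U-Nets).
# Assumption: scores are y = X'p / ‖p‖ and kept features are gated by sigmoid(y).
sigmoid(z) = 1 / (1 + exp(-z))

function topk_pool(A::AbstractMatrix, X::AbstractMatrix, p::AbstractVector, k::Integer)
    y = (X' * p) ./ sqrt(sum(abs2, p))                 # one score per node (X stores nodes as columns)
    idx = collect(partialsortperm(y, 1:k; rev = true)) # indices of the k highest scores, best first
    Xk = X[:, idx] .* sigmoid.(y[idx])'                # scale each kept node's features by its gate
    Ak = A[idx, idx]                                   # adjacency of the subgraph induced by kept nodes
    return Ak, Xk, idx
end

A = [0 1 0 0;
     1 0 1 0;
     0 1 0 1;
     0 0 1 0]              # path graph 1-2-3-4
X = [1.0 -2.0 3.0 0.5;
     0.0  1.0 1.0 2.0]     # 2 features × 4 nodes
p = [1.0, 0.0]             # hypothetical fixed projection: score = first feature
Ak, Xk, idx = topk_pool(A, X, p, 2)   # keeps nodes 3 and 1
```

Here nodes 3 and 1 are not adjacent in the path, so Ak has no edges: selecting nodes by score alone can disconnect the pooled graph.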