Node Classification with Graph Neural Networks
In this tutorial, we will learn how to use Graph Neural Networks (GNNs) for node classification. Given the ground-truth labels of only a small subset of nodes, we want to infer the labels of all the remaining nodes (transductive learning).
Import
Let us start off by importing some libraries. We will be using Lux.jl and GNNLux.jl for our tutorial.
using Lux, GNNLux
using MLDatasets
using Plots, TSne
using Random, Statistics
using Zygote, Optimisers, OneHotArrays, ConcreteStructs
ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # don't ask for dataset download confirmation
rng = Random.seed!(17); # for reproducibility
Visualize
We want to visualize our results using t-distributed stochastic neighbor embedding (tsne) to project our output onto a 2D plane.
function visualize_tsne(out, targets)
z = tsne(out, 2)
scatter(z[:, 1], z[:, 2], color = Int.(targets[1:size(z, 1)]), leg = false)
end;
Dataset: Cora
For our tutorial, we will be using the Cora dataset. Cora is a citation network of 2,708 documents categorized into seven classes, with 5,429 citation links. Each node represents an article or document, and edges between nodes indicate a citation relationship, where one cites the other.
Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1,433 unique words.
This dataset was first introduced by Yang et al. (2016) as one of the datasets of the Planetoid benchmark suite. We will be using MLDatasets.jl for easy access to this dataset.
dataset = Cora()
dataset Cora:
metadata => Dict{String, Any} with 3 entries
graphs => 1-element Vector{MLDatasets.Graph}
Datasets in MLDatasets.jl have metadata containing information about the dataset itself.
dataset.metadata
Dict{String, Any} with 3 entries:
"name" => "cora"
"classes" => [1, 2, 3, 4, 5, 6, 7]
"num_classes" => 7
The graphs variable contains the graph. The Cora dataset contains only one graph.
dataset.graphs
1-element Vector{MLDatasets.Graph}:
Graph(2708, 10556)
There is only one graph in the dataset. Its node_data contains features, indicating whether certain words are present or not, and targets, indicating the class of each document. We convert the single-graph dataset to a GNNGraph.
g = mldataset2gnngraph(dataset)
println("Number of nodes: $(g.num_nodes)")
println("Number of edges: $(g.num_edges)")
println("Average node degree: $(g.num_edges / g.num_nodes)")
println("Number of training nodes: $(sum(g.ndata.train_mask))")
println("Training node label rate: $(mean(g.ndata.train_mask))")
println("Has isolated nodes: $(has_isolated_nodes(g))")
println("Has self-loops: $(has_self_loops(g))")
println("Is undirected: $(is_bidirected(g))")
Number of nodes: 2708
Number of edges: 10556
Average node degree: 3.8980797636632203
Number of training nodes: 140
Training node label rate: 0.051698670605613
Has isolated nodes: false
Has self-loops: false
Is undirected: true
Overall, this dataset is quite similar to the previously used KarateClub network. We can see that the Cora network holds 2,708 nodes and 10,556 edges, resulting in an average node degree of 3.9. For training this dataset, we are given the ground-truth categories of 140 nodes (20 for each class), which results in a training node label rate of only 5%.
We can further see that this network is undirected, and that there are no isolated nodes (each document has at least one citation).
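As a quick sanity check (a minimal sketch, not part of the original tutorial), we can confirm the 20-per-class split by counting the training nodes in each class, using the fields shown above:
# count the training nodes in each of the 7 classes; every entry should be 20
[count(==(c), g.ndata.targets[g.ndata.train_mask]) for c in 1:7]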
x = g.ndata.features
# we one-hot encode the node labels (what we want to predict):
y = onehotbatch(g.ndata.targets, 1:7)
train_mask = g.ndata.train_mask;
num_features = size(x, 1);
hidden_channels = 16;
drop_rate = 0.5;
num_classes = dataset.metadata["num_classes"];
Multi-layer Perceptron Network (MLP)
In theory, we should be able to infer the category of a document solely based on its content, i.e. its bag-of-words feature representation, without taking any relational information into account.
Let's verify that by constructing a simple MLP that solely operates on input node features (using shared weights across all nodes):
MLP = Chain(Dense(num_features => hidden_channels, relu),
Dropout(drop_rate),
Dense(hidden_channels => num_classes))
ps, st = Lux.setup(rng, MLP);
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:18
Training a Multilayer Perceptron
Our MLP is defined by two linear layers, enhanced by a ReLU non-linearity and dropout. The first linear layer reduces the 1433-dimensional feature vector to a low-dimensional embedding (hidden_channels = 16), while the second linear layer acts as a classifier, mapping each low-dimensional node embedding to one of the 7 classes.
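Before training, it is worth checking the model's capacity. The one-liner below is an addition for illustration; parameterlength comes from LuxCore, which Lux re-exports.
# count the trainable parameters of the MLP:
# 1433*16 + 16 in the first Dense layer, 16*7 + 7 in the second
println("MLP parameters: $(Lux.parameterlength(MLP))")
With roughly 23,000 trainable parameters but only 140 labeled training nodes, heavy overfitting is a real risk, as we will see below.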
Let's train our simple MLP by following a similar procedure as described in the first part of this tutorial. We again make use of the cross-entropy loss and the Adam optimizer. This time, we also define an accuracy function to evaluate how well our final model performs on the test node set (whose labels have not been observed during training).
function loss(model, ps, st, x)
    # Lux's training API expects the objective to return (loss, state, stats)
    logitcrossentropy = CrossEntropyLoss(; logits = Val(true))
    ŷ, st = model(x, ps, st)
    return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), st, 0
end
function train_model!(MLP, ps, st, x, epochs)
    train_state = Lux.Training.TrainState(MLP, ps, st, Adam(1e-3))
    for iter in 1:epochs
        _, loss_value, _, train_state = Lux.Training.single_train_step!(AutoZygote(), loss, x, train_state)
        if iter % 100 == 0
            println("Epoch: $(iter) Loss: $(loss_value)")
        end
    end
end
function accuracy(model, x, ps, st, y, mask)
    st = Lux.testmode(st) # disable dropout for evaluation
    ŷ, st = model(x, ps, st)
    mean(onecold(ŷ)[mask] .== onecold(y)[mask])
end
train_model!(MLP, ps, st, x, 2000)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:18
Epoch: 100 Loss: 0.810594
Epoch: 200 Loss: 0.48982772
Epoch: 300 Loss: 0.31716076
Epoch: 400 Loss: 0.2397098
Epoch: 500 Loss: 0.20041731
Epoch: 600 Loss: 0.11589075
Epoch: 700 Loss: 0.21093586
Epoch: 800 Loss: 0.18869051
Epoch: 900 Loss: 0.15322906
Epoch: 1000 Loss: 0.12451931
Epoch: 1100 Loss: 0.13396983
Epoch: 1200 Loss: 0.111468166
Epoch: 1300 Loss: 0.17113678
Epoch: 1400 Loss: 0.18155631
Epoch: 1500 Loss: 0.17731342
Epoch: 1600 Loss: 0.11386197
Epoch: 1700 Loss: 0.09408201
Epoch: 1800 Loss: 0.15806198
Epoch: 1900 Loss: 0.104388796
Epoch: 2000 Loss: 0.18465123
After training the model, we can call the accuracy function to see how well our model performs on unseen labels. Here, we are interested in the accuracy of the model, i.e., the ratio of correctly classified nodes:
accuracy(MLP, x, ps, st, y, .!train_mask)
0.5089563862928349
As one can see, our MLP performs rather poorly, reaching only about 50% test accuracy. But why does the MLP not perform better? The main reason is that the model suffers from heavy overfitting: it only has access to a small number of training nodes, and it therefore generalizes poorly to unseen node representations.
The MLP also fails to incorporate an important bias into the model: cited papers are very likely related to the category of a document. That is exactly where Graph Neural Networks come into play and can help to boost the performance of our model.
Training a Graph Convolutional Network (GCN)
Following up on the first part of this tutorial, we replace the Dense linear layers with the GCNConv module. To recap, the GCN layer (Kipf et al. (2017)) is defined as
\[\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)}\]
where $\mathbf{W}^{(\ell + 1)}$ denotes a trainable weight matrix of shape [num_output_features, num_input_features] and $c_{w,v}$ refers to a fixed normalization coefficient for each edge. In contrast, a single Linear layer is defined as
\[\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \mathbf{x}_v^{(\ell)}\]
which does not make use of neighboring node information.
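To make this difference concrete, here is a minimal sketch of a single GCN propagation step computed by hand. It assumes the symmetric normalization $c_{w,v} = \sqrt{\deg(w)\,\deg(v)}$ of Kipf et al., the adjacency_matrix function from GNNGraphs.jl (on which GNNLux builds), and the LinearAlgebra standard library for I:
using LinearAlgebra
A = adjacency_matrix(g) + I                        # N(v) ∪ {v}: add self-loops
d = vec(sum(A; dims = 2))                          # node degrees after adding self-loops
Anorm = (1 ./ sqrt.(d)) .* A .* (1 ./ sqrt.(d))'   # entries are 1 / c_{w,v}
W = randn(Float32, hidden_channels, num_features)  # stand-in for the trainable W
h = W * x * Anorm                                  # output of shape [hidden_channels, num_nodes]
The GCNConv layer used below performs this aggregation for us (with a trainable weight matrix and optional bias), without requiring us to build Anorm explicitly.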
@concrete struct GCN <: GNNContainerLayer{(:conv1, :drop, :conv2)}
    nf::Int    # number of input features
    nc::Int    # number of classes
    hd::Int    # hidden dimension
    conv1
    conv2
    drop
    use_bias::Bool
    init_weight
    init_bias
end;
function GCN(num_features, num_classes, hidden_channels, drop_rate;
             use_bias = true, init_weight = glorot_uniform, init_bias = zeros32) # constructor
    conv1 = GCNConv(num_features => hidden_channels)
    conv2 = GCNConv(hidden_channels => num_classes)
    drop = Dropout(drop_rate)
    return GCN(num_features, num_classes, hidden_channels, conv1, conv2, drop, use_bias, init_weight, init_bias)
end;
function (gcn::GCN)(g::GNNGraph, x, ps, st) # forward pass
    x, stconv1 = gcn.conv1(g, x, ps.conv1, st.conv1) # first graph convolution
    x = relu.(x)
    x, stdrop = gcn.drop(x, ps.drop, st.drop)        # dropout between the layers
    x, stconv2 = gcn.conv2(g, x, ps.conv2, st.conv2) # second graph convolution
    return x, (conv1 = stconv1, drop = stdrop, conv2 = stconv2)
end;
Now let's visualize the node embeddings of our untrained GCN network.
gcn = GCN(num_features, num_classes, hidden_channels, drop_rate)
ps, st = Lux.setup(rng, gcn)
h_untrained, st = gcn(g, x, ps, st)
h_untrained = h_untrained |> transpose
visualize_tsne(h_untrained, g.ndata.targets)
We certainly can do better by training our model. The training and testing procedure is once again the same, but this time we make use of the node features x and the graph g as input to our GCN model.
function loss(gcn, ps, st, tuple)
    g, x, y = tuple
    logitcrossentropy = CrossEntropyLoss(; logits = Val(true))
    ŷ, st = gcn(g, x, ps, st)
    return logitcrossentropy(ŷ[:, train_mask], y[:, train_mask]), st, 0
end
function train_model!(gcn, ps, st, g, x, y)
    train_state = Lux.Training.TrainState(gcn, ps, st, Adam(1e-2))
    for iter in 1:2000
        _, loss_value, _, train_state = Lux.Training.single_train_step!(AutoZygote(), loss, (g, x, y), train_state)
        if iter % 100 == 0
            println("Epoch: $(iter) Loss: $(loss_value)")
        end
    end
    return gcn, ps, st
end
gcn, ps, st = train_model!(gcn, ps, st, g, x, y);
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore ~/.julia/packages/LuxCore/GlbG3/src/LuxCore.jl:18
Epoch: 100 Loss: 0.019381031
Epoch: 200 Loss: 0.017426146
Epoch: 300 Loss: 0.006051709
Epoch: 400 Loss: 0.0015434261
Epoch: 500 Loss: 0.0052008606
Epoch: 600 Loss: 0.025294377
Epoch: 700 Loss: 0.0012917791
Epoch: 800 Loss: 0.005089373
Epoch: 900 Loss: 0.00912053
Epoch: 1000 Loss: 0.002442247
Epoch: 1100 Loss: 0.00024606875
Epoch: 1200 Loss: 0.00046606906
Epoch: 1300 Loss: 0.002437515
Epoch: 1400 Loss: 0.00019191795
Epoch: 1500 Loss: 0.0056298207
Epoch: 1600 Loss: 0.00020503976
Epoch: 1700 Loss: 0.0028860446
Epoch: 1800 Loss: 0.02319943
Epoch: 1900 Loss: 0.00030635786
Epoch: 2000 Loss: 0.00013437525
Now let's evaluate the accuracy of our trained GCN.
function accuracy(model, g, x, ps, st, y, mask)
    st = Lux.testmode(st) # disable dropout for evaluation
    ŷ, st = model(g, x, ps, st)
    mean(onecold(ŷ)[mask] .== onecold(y)[mask])
end
train_accuracy = accuracy(gcn, g, g.ndata.features, ps, st, y, train_mask)
test_accuracy = accuracy(gcn, g, g.ndata.features, ps, st, y, .!train_mask)
println("Train accuracy: $(train_accuracy)")
println("Test accuracy: $(test_accuracy)")
Train accuracy: 1.0
Test accuracy: 0.7636292834890965
There it is! By simply swapping the linear layers with GNN layers, we reach 76% test accuracy! This is in stark contrast to the 50% test accuracy obtained by our MLP, indicating that relational information plays a crucial role in obtaining better performance.
We can also verify that once again by looking at the output embeddings of our trained model, which now produces a far better clustering of nodes of the same category.
st = Lux.testmode(st) # inference mode
out_trained, st = gcn(g, x, ps, st)
out_trained = out_trained |> transpose
visualize_tsne(out_trained, g.ndata.targets)
(Optional) Exercises
1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The Cora dataset provides a validation node set as g.ndata.val_mask, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance to > 80% accuracy. (A minimal sketch of this idea follows the list.)
2. How does GCN behave when increasing the hidden feature dimensionality or the number of layers? Does increasing the number of layers help at all?
3. You can try to use different GNN layers to see how model performance changes. What happens if you swap out all GCNConv instances with GATConv layers that make use of attention? Try to write a 2-layer GAT model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a dropout ratio of 0.6 inside and outside each GATConv call, and uses a hidden_channels dimension of 8 per head.
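For the first exercise, a minimal sketch of validation-based model selection could look as follows. It reuses the loss and accuracy functions defined above; the function name train_with_validation! and the deepcopy snapshotting are illustrative choices, not part of the original tutorial.
function train_with_validation!(gcn, ps, st, g, x, y)
    val_mask = g.ndata.val_mask
    train_state = Lux.Training.TrainState(gcn, ps, st, Adam(1e-2))
    best_val_acc, best_ps = 0.0, deepcopy(ps)
    for iter in 1:2000
        _, loss_value, _, train_state = Lux.Training.single_train_step!(AutoZygote(), loss, (g, x, y), train_state)
        # as in the training loops above, `single_train_step!` updates the parameter arrays in place
        val_acc = accuracy(gcn, g, x, ps, st, y, val_mask)
        if val_acc > best_val_acc
            best_val_acc, best_ps = val_acc, deepcopy(ps) # snapshot the best parameters
        end
    end
    return best_ps, best_val_acc
end
Evaluating accuracy(gcn, g, x, best_ps, st, y, .!train_mask) with the returned parameters then tests the model selected on validation performance.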
Conclusion
In this tutorial, we have seen how to apply GNNs to real-world problems, and, in particular, how they can effectively be used for boosting a model's performance. In the next tutorial, we will look into how GNNs can be used for the task of graph classification.
This page was generated using Literate.jl.