Node Classification with Graph Neural Networks
In this tutorial, we will learn how to use Graph Neural Networks (GNNs) for node classification. Given the ground-truth labels of only a small subset of nodes, we want to infer the labels of all the remaining nodes (transductive learning).
Import
Let us start off by importing some libraries. We will be using Flux.jl and GraphNeuralNetworks.jl
for our tutorial.
begin
    using MLDatasets
    using GraphNeuralNetworks
    using Flux
    using Flux: onecold, onehotbatch, logitcrossentropy
    using Plots
    using PlutoUI
    using TSne
    using Random
    using Statistics
    ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"
    Random.seed!(17) # for reproducibility
end;
Visualize
We want to visualize the outputs of our model using t-distributed stochastic neighbor embedding (t-SNE), which projects the output embeddings onto a 2D plane.
function visualize_tsne(out, targets)
    z = tsne(out, 2)
    scatter(z[:, 1], z[:, 2], color = Int.(targets[1:size(z, 1)]), leg = false)
end
visualize_tsne (generic function with 1 method)
Dataset: Cora
For our tutorial, we will be using the Cora
dataset. Cora
is a citation network of 2708 documents, each classified into one of seven classes, connected by 5429 links. Each node represents a document, and an edge connects two nodes if one of the documents cites the other.
Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words.
This dataset was first introduced by Yang et al. (2016) as one of the datasets of the Planetoid
benchmark suite. We will be using MLDatasets.jl for easy access to this dataset.
dataset = Cora()
dataset Cora: metadata => Dict{String, Any} with 3 entries graphs => 1-element Vector{MLDatasets.Graph}
Datasets in MLDatasets.jl have metadata
containing information about the dataset itself.
dataset.metadata
Dict{String, Any} with 3 entries: "name" => "cora" "classes" => [1, 2, 3, 4, 5, 6, 7] "num_classes" => 7
The graphs
field holds the graphs of the dataset. The Cora
dataset contains only one graph.
dataset.graphs
1-element Vector{MLDatasets.Graph}: Graph(2708, 10556)
There is only one graph in the dataset. Its node_data
contains features,
indicating whether certain words are present, and targets,
indicating the class of each document. We convert the single-graph dataset to a GNNGraph.
g = mldataset2gnngraph(dataset)
GNNGraph: num_nodes: 2708 num_edges: 10556 ndata: val_mask = 2708-element BitVector targets = 2708-element Vector{Int64} test_mask = 2708-element BitVector features = 1433×2708 Matrix{Float32} train_mask = 2708-element BitVector
with_terminal() do
    # Gather some statistics about the graph.
    println("Number of nodes: $(g.num_nodes)")
    println("Number of edges: $(g.num_edges)")
    println("Average node degree: $(g.num_edges / g.num_nodes)")
    println("Number of training nodes: $(sum(g.ndata.train_mask))")
    println("Training node label rate: $(mean(g.ndata.train_mask))")
    # println("Has isolated nodes: $(has_isolated_nodes(g))")
    println("Has self-loops: $(has_self_loops(g))")
    println("Is undirected: $(is_bidirected(g))")
end
Number of nodes: 2708 Number of edges: 10556 Average node degree: 3.8980797636632203 Number of training nodes: 140 Training node label rate: 0.051698670605613 Has self-loops: false Is undirected: true
Overall, this dataset is quite similar to the previously used KarateClub
network. We can see that the Cora
network holds 2,708 nodes and 10,556 edges, resulting in an average node degree of 3.9. For training on this dataset, we are given the ground-truth categories of 140 nodes (20 for each class). This results in a training node label rate of only 5%.
We can further see that this network is undirected and has no self-loops. It also has no isolated nodes (each document has at least one citation), although the corresponding check is commented out above.
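To make these statistics concrete, here is a minimal sketch that computes them by hand from a directed edge list. The toy graph and the names src, dst, and edge_list are hypothetical and not part of the Cora pipeline. Note that a bidirected graph stores each link once per direction, which is why the edge count used for the average degree counts both directions.

```julia
# Hypothetical toy graph: directed edge list over nodes 1..4.
# Node 4 does not appear in any edge.
src = [1, 2, 2, 3]
dst = [2, 1, 3, 2]
num_nodes = 4

edge_list = collect(zip(src, dst))

# Same ratio as g.num_edges / g.num_nodes in the cell above.
avg_degree = length(edge_list) / num_nodes

# A self-loop is an edge whose source equals its target.
self_loops = any(s == t for (s, t) in edge_list)

# Bidirected: every edge (s, t) also appears reversed as (t, s).
bidirected = all((t, s) in edge_list for (s, t) in edge_list)

# Isolated nodes are those that appear in no edge at all.
isolated = setdiff(1:num_nodes, union(src, dst))
```

On this toy graph, node 4 is the only isolated node, there are no self-loops, and every edge has its reverse.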
begin
    x = g.ndata.features
    # we one-hot encode the node labels (what we want to predict):
    y = onehotbatch(g.ndata.targets, 1:7)
    train_mask = g.ndata.train_mask
    num_features = size(x, 1)
    hidden_channels = 16
    num_classes = dataset.metadata["num_classes"]
end;
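As a small illustration of the encoding used above, the following sketch applies onehotbatch and onecold to a hypothetical vector of four labels (the names labels and y_demo are made up for this example). Each label becomes one column of the resulting matrix, and onecold maps the columns back to integer labels.

```julia
using Flux: onehotbatch, onecold

labels = [3, 1, 7, 3]                 # hypothetical class labels for four nodes
y_demo = onehotbatch(labels, 1:7)     # one column per node, one row per class

size(y_demo)                          # (7, 4)
onecold(y_demo)                       # recovers the integer labels
```

This column-per-node layout is why the loss later indexes predictions and targets as ŷ[:, train_mask] and y[:, train_mask].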
Multi-layer Perceptron Network (MLP)
In theory, we should be able to infer the category of a document solely based on its content, i.e. its bag-of-words feature representation, without taking any relational information into account.
Let's verify that by constructing a simple MLP that solely operates on input node features (using shared weights across all nodes):
begin
    struct MLP
        layers::NamedTuple
    end

    Flux.@layer :expand MLP

    function MLP(num_features, num_classes, hidden_channels; drop_rate = 0.5)
        layers = (hidden = Dense(num_features => hidden_channels),
                  drop = Dropout(drop_rate),
                  classifier = Dense(hidden_channels => num_classes))
        return MLP(layers)
    end

    function (model::MLP)(x::AbstractMatrix)
        l = model.layers
        x = l.hidden(x)
        x = relu.(x) # activation functions are applied element-wise via broadcasting
        x = l.drop(x)
        x = l.classifier(x)
        return x
    end
end
Training a Multilayer Perceptron
Our MLP is defined by two linear layers and enhanced by ReLU non-linearity and Dropout. Here, we first reduce the 1433-dimensional feature vector to a low-dimensional embedding (hidden_channels=16
), while the second linear layer acts as a classifier that should map each low-dimensional node embedding to one of the 7 classes.
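The training and evaluation code in this tutorial switches between Flux.trainmode! and Flux.testmode!, which matters mainly because of the Dropout layer. The following is a minimal sketch on a hypothetical input of ones (not part of the tutorial's pipeline) showing what Dropout does in each mode.

```julia
using Flux

x_demo = ones(Float32, 1000)   # hypothetical input, chosen so the effect is easy to read
drop = Dropout(0.5)

Flux.testmode!(drop)           # inference: dropout leaves the input unchanged
y_test = drop(x_demo)

Flux.trainmode!(drop)          # training: entries are zeroed at random,
y_train = drop(x_demo)         # and the survivors are scaled by 1 / (1 - 0.5)
```

In training mode every entry of y_train is either 0 or 2, so the expected value of each entry stays the same as in test mode.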
Let's train our simple MLP by following a similar procedure as described in the first part of this tutorial. We again make use of the cross entropy loss and Adam optimizer. This time, we also define an accuracy
function to evaluate how well our final model performs on the test node set (whose labels have not been observed during training).
function train(model::MLP, data::AbstractMatrix, epochs::Int, opt)
    Flux.trainmode!(model)

    for epoch in 1:epochs
        loss, grad = Flux.withgradient(model) do model
            ŷ = model(data)
            logitcrossentropy(ŷ[:, train_mask], y[:, train_mask])
        end

        Flux.update!(opt, model, grad[1])
        if epoch % 200 == 0
            @show epoch, loss
        end
    end
end
train (generic function with 1 method)
function accuracy(model::MLP, x::AbstractMatrix, y::Flux.OneHotArray, mask::BitVector)
    Flux.testmode!(model)
    mean(onecold(model(x))[mask] .== onecold(y)[mask])
end
accuracy (generic function with 1 method)
begin
    mlp = MLP(num_features, num_classes, hidden_channels)
    opt_mlp = Flux.setup(Adam(1e-3), mlp)
    epochs = 2000
    train(mlp, g.ndata.features, epochs, opt_mlp)
end
After training the model, we can call the accuracy
function to see how well our model performs on unseen labels. Here, we are interested in the accuracy of the model, i.e., the ratio of correctly classified nodes, evaluated on all nodes outside the training set (.!train_mask):
accuracy(mlp, g.ndata.features, y, .!train_mask)
0.45872274143302183
As one can see, our MLP performs rather badly, with only about 46% test accuracy. But why does the MLP not perform better? The main reason is that this model suffers from heavy overfitting because it only has access to a small number of training nodes, and therefore generalizes poorly to unseen node representations.
It also fails to incorporate an important inductive bias into the model: a document is very likely to belong to the same category as the papers it cites. That is exactly where Graph Neural Networks come into play and can help to boost the performance of our model.
Training a Graph Convolutional Network (GCN)
Following up on the first part of this tutorial, we replace the Dense
linear layers by the GCNConv
module. To recap, the GCN layer (Kipf et al. (2017)) is defined as
$$\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)}$$
where \(\mathbf{W}^{(\ell + 1)}\) denotes a trainable weight matrix of shape [num_output_features, num_input_features]
and \(c_{w,v}\) refers to a fixed normalization coefficient for each edge. In contrast, a single Dense
layer is defined as
$$\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \mathbf{x}_v^{(\ell)}$$
which does not make use of neighboring node information.
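To see the difference between the two formulas on actual numbers, here is a minimal dense-matrix sketch on a hypothetical three-node path graph. It assumes the symmetric normalization of Kipf et al. (2017), \(c_{w,v} = \sqrt{\hat{d}_w \hat{d}_v}\), where \(\hat{d}_v\) is the degree of node v counted with its self-loop. This is an illustration of the formula only, not the internals of GCNConv, and the names A_hat, d_hat, coeff, and W are made up for this example.

```julia
using LinearAlgebra

# Hypothetical undirected path graph 1 - 2 - 3 as a dense adjacency matrix.
A = [0 1 0;
     1 0 1;
     0 1 0]

A_hat = A + I                         # add self-loops, i.e. N(v) ∪ {v}
d_hat = vec(sum(A_hat, dims = 2))     # node degrees including the self-loop

# Assumed coefficient c_{w,v} = sqrt(d_hat[w] * d_hat[v]) for every edge (w, v).
coeff = A_hat ./ sqrt.(d_hat * d_hat')

# Node features as columns (features × nodes), the same layout as g.ndata.features.
X = Float32[1 0 2;
            0 1 1]
W = Float32[1 -1]                     # hypothetical 1×2 weight matrix

gcn_out = W * X * coeff               # normalized sum over neighbors, then W
dense_out = W * X                     # Dense layer without bias: each node on its own
```

For node 1, dense_out only sees the node's own features, while gcn_out mixes in the features of node 2 weighted by \(1/\sqrt{6}\).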
begin
    struct GCN
        layers::NamedTuple
    end

    Flux.@layer GCN # provides parameter collection, gpu movement and more

    function GCN(num_features, num_classes, hidden_channels; drop_rate = 0.5)
        layers = (conv1 = GCNConv(num_features => hidden_channels),
                  drop = Dropout(drop_rate),
                  conv2 = GCNConv(hidden_channels => num_classes))
        return GCN(layers)
    end

    function (gcn::GCN)(g::GNNGraph, x::AbstractMatrix)
        l = gcn.layers
        x = l.conv1(g, x)
        x = relu.(x)
        x = l.drop(x)
        x = l.conv2(g, x)
        return x
    end
end
Now let's visualize the node embeddings of our untrained GCN network.
begin
    gcn = GCN(num_features, num_classes, hidden_channels)
    h_untrained = gcn(g, x) |> transpose
    visualize_tsne(h_untrained, g.ndata.targets)
end
We certainly can do better by training our model. The training and testing procedure is once again the same, but this time we make use of the node features x
and the graph g
as input to our GCN model.
function train(model::GCN, g::GNNGraph, x::AbstractMatrix, epochs::Int, opt)
    Flux.trainmode!(model)

    for epoch in 1:epochs
        loss, grad = Flux.withgradient(model) do model
            ŷ = model(g, x)
            logitcrossentropy(ŷ[:, train_mask], y[:, train_mask])
        end

        Flux.update!(opt, model, grad[1])
        if epoch % 200 == 0
            @show epoch, loss
        end
    end
end
train (generic function with 2 methods)
function accuracy(model::GCN, g::GNNGraph, x::AbstractMatrix, y::Flux.OneHotArray,
                  mask::BitVector)
    Flux.testmode!(model)
    mean(onecold(model(g, x))[mask] .== onecold(y)[mask])
end
accuracy (generic function with 2 methods)
begin
    opt_gcn = Flux.setup(Adam(1e-2), gcn)
    train(gcn, g, x, epochs, opt_gcn)
end
Now let's evaluate the accuracy of our trained GCN.
with_terminal() do
    train_accuracy = accuracy(gcn, g, g.ndata.features, y, train_mask)
    test_accuracy = accuracy(gcn, g, g.ndata.features, y, .!train_mask)

    println("Train accuracy: $(train_accuracy)")
    println("Test accuracy: $(test_accuracy)")
end
Train accuracy: 1.0 Test accuracy: 0.7706386292834891
There it is! By simply swapping the linear layers with GNN layers, we can reach 77.06% test accuracy! This is in stark contrast to the 46% test accuracy obtained by our MLP, indicating that relational information plays a crucial role in obtaining better performance.
We can also verify that once again by looking at the output embeddings of our trained model, which now produces a far better clustering of nodes of the same category.
begin
    Flux.testmode!(gcn) # inference mode

    out_trained = gcn(g, x) |> transpose
    visualize_tsne(out_trained, g.ndata.targets)
end
(Optional) Exercises
1. To achieve better model performance and to avoid overfitting, it is usually a good idea to select the best model based on an additional validation set. The Cora dataset provides a validation node set as g.ndata.val_mask, but we haven't used it yet. Can you modify the code to select and test the model with the highest validation performance? This should bring test performance to 82% accuracy.
2. How does GCN behave when increasing the hidden feature dimensionality or the number of layers? Does increasing the number of layers help at all?
3. You can try to use different GNN layers to see how model performance changes. What happens if you swap out all GCNConv instances with GATConv layers that make use of attention? Try to write a 2-layer GAT model that makes use of 8 attention heads in the first layer and 1 attention head in the second layer, uses a dropout ratio of 0.6 inside and outside each GATConv call, and uses a hidden_channels dimension of 8 per head.
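As a starting point for the first exercise, one possible pattern is to keep a copy of the model whenever the validation accuracy improves. The sketch below shows this on synthetic data with a small hypothetical Chain model. The names X_demo, Y_demo, train_idx, val_idx, acc, and fit_with_selection! are all made up for illustration, and the code would still need to be adapted to the GCN, the graph g, and g.ndata.val_mask.

```julia
using Flux
using Flux: onecold, onehotbatch, logitcrossentropy
using Random, Statistics

Random.seed!(0)

# Synthetic stand-in data (hypothetical): 4 features, 3 classes, 60 samples.
X_demo = randn(Float32, 4, 60)
labels_demo = rand(1:3, 60)
Y_demo = onehotbatch(labels_demo, 1:3)
train_idx, val_idx = 1:40, 41:60

model = Chain(Dense(4 => 8, relu), Dense(8 => 3))
opt = Flux.setup(Adam(1e-2), model)

# Fraction of correctly classified samples among the given indices.
acc(m, idx) = mean(onecold(m(X_demo[:, idx])) .== labels_demo[idx])

function fit_with_selection!(model, opt; epochs = 50)
    best_acc, best_model = -Inf, deepcopy(model)
    for epoch in 1:epochs
        grads = Flux.gradient(model) do m
            logitcrossentropy(m(X_demo[:, train_idx]), Y_demo[:, train_idx])
        end
        Flux.update!(opt, model, grads[1])

        # Keep a snapshot whenever validation accuracy strictly improves.
        val_acc = acc(model, val_idx)
        if val_acc > best_acc
            best_acc, best_model = val_acc, deepcopy(model)
        end
    end
    return best_model, best_acc
end

best_model, best_acc = fit_with_selection!(model, opt)
```

The returned best_model is the snapshot to evaluate on the test nodes, rather than the model from the final epoch.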
Conclusion
In this tutorial, we have seen how to apply GNNs to real-world problems, and, in particular, how they can effectively be used for boosting a model's performance. In the next tutorial, we will look into how GNNs can be used for the task of graph classification.