I borrowed some code from the "Introduction to Graph Neural Nets with JAX/jraph" tutorial and from the PyTorch Geometric tutorials:
@inproceedings{Fey/Lenssen/2019,
  title     = {Fast Graph Representation Learning with {PyTorch Geometric}},
  author    = {Fey, Matthias and Lenssen, Jan E.},
  booktitle = {ICLR Workshop on Representation Learning on Graphs and Manifolds},
  year      = {2019},
}
@software{jraph2020github,
  author  = {Jonathan Godwin* and Thomas Keck* and Peter Battaglia and Victor Bapst and Thomas Kipf and Yujia Li and Kimberly Stachenfeld and Petar Veli\v{c}kovi\'{c} and Alvaro Sanchez-Gonzalez},
  title   = {Jraph: A library for graph neural networks in jax.},
  url     = {http://github.com/deepmind/jraph},
  version = {0.0.1.dev},
  year    = {2020},
}
Here are the libraries you need to implement a GCN in PyTorch or JAX:
The original GCN uses two layers, as shown in equation (9) of Kipf & Welling (2017), https://arxiv.org/abs/1609.02907.
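For reference, that equation defines the full two-layer forward model, with \hat{A} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} the symmetrically normalized adjacency matrix with self-loops:

Z = \mathrm{softmax}\big(\hat{A}\,\mathrm{ReLU}(\hat{A} X W^{(0)})\,W^{(1)}\big)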
# install pytorch_geometric
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
# install jraph and jax dependencies
!pip install git+https://github.com/deepmind/jraph.git
# Standard libraries
import numpy as np
from scipy import sparse
import seaborn as sns
import pandas as pd
# Plotting libraries
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib import cm
from IPython.display import Javascript # Restrict height of output cell.
# sklearn
from sklearn.manifold import TSNE
# PyTorch
import torch
import torch.nn.functional as F
from torch.nn import Linear
# PyTorch geometric
from torch_geometric.utils import from_scipy_sparse_matrix
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import (Planetoid, KarateClub)
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.data import Data
from torch_geometric import seed_everything
# jax
import jax
import jax.numpy as jnp
import jax.tree_util as tree
import flax
import flax.linen as nn
from flax.training import train_state
import optax
import pickle
# jraph
import jraph
from jraph._src import models as jraph_models
random_seed = 42
plt.style.use('dark_background')
plot_colors = cm.tab10.colors
accuracy_list = []
You can find a description of this dataset in the PyTorch Geometric documentation. With split='full', all nodes except those in the validation and test sets are used for training.
dataset = Planetoid(root='data/Planetoid', name='Cora', split='full', transform=NormalizeFeatures())
num_features = dataset.num_features
num_classes = dataset.num_classes
data_Cora = dataset[0] # Get the first graph object.
data_Cora
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
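Before converting to jraph, it's worth sanity-checking that the three masks are disjoint and seeing how many nodes fall in each split; a minimal sketch using the standard PyG boolean mask attributes:

# Count nodes per split and confirm the masks do not overlap.
for name, mask in [('train', data_Cora.train_mask),
                   ('val', data_Cora.val_mask),
                   ('test', data_Cora.test_mask)]:
    print(f'{name}: {int(mask.sum())} nodes')
assert not (data_Cora.train_mask & data_Cora.val_mask).any()
assert not (data_Cora.train_mask & data_Cora.test_mask).any()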
graph = jraph.GraphsTuple(
n_node=jnp.asarray([data_Cora.x.shape[0]]),
n_edge=jnp.asarray([data_Cora.edge_index.shape[1]]),
nodes=jnp.asarray(data_Cora.x),
# No edge features.
edges=None,
globals=None,
senders=jnp.asarray([data_Cora.edge_index[0,:]]).squeeze(),
receivers=jnp.asarray([data_Cora.edge_index[1,:]]).squeeze())
graph_train_mask = jnp.asarray([data_Cora.train_mask]).squeeze()
graph_val_mask = jnp.asarray([data_Cora.val_mask]).squeeze()
graph_test_mask = jnp.asarray([data_Cora.test_mask]).squeeze()
graph_labels = jnp.asarray([data_Cora.y]).squeeze()
one_hot_labels = jax.nn.one_hot(graph_labels, len(jnp.unique(graph_labels)))
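The same conversion is repeated verbatim for CiteSeer below, so it could be factored into a small helper. A sketch (pyg_to_jraph is my own name, not part of either library):

def pyg_to_jraph(data):
    """Converts a PyTorch Geometric Data object into a jraph.GraphsTuple
    plus JAX arrays for the split masks and labels."""
    g = jraph.GraphsTuple(
        n_node=jnp.asarray([data.x.shape[0]]),
        n_edge=jnp.asarray([data.edge_index.shape[1]]),
        nodes=jnp.asarray(data.x),
        edges=None,    # no edge features
        globals=None,
        senders=jnp.asarray(data.edge_index[0]),
        receivers=jnp.asarray(data.edge_index[1]))
    masks = tuple(jnp.asarray(m) for m in
                  (data.train_mask, data.val_mask, data.test_mask))
    return g, masks, jnp.asarray(data.y)
# Usage: graph, (train_mask, val_mask, test_mask), labels = pyg_to_jraph(data_Cora)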
Let's check that the node and edge counts are right, and that the arrays now carry JAX dtypes (float32 rather than torch.float32):
print(f'Number of nodes: {graph.n_node[0]}')
print(f'Number of edges: {graph.n_edge[0]}')
print(f'Feature matrix data type: {graph.nodes.dtype}')
print(f'senders list data type: {graph.senders.dtype}')
print(f'receivers list data type: {graph.receivers.dtype}')
print(f'Labels matrix data type: {graph_labels.dtype}')
Number of nodes: 2708
Number of edges: 10556
Feature matrix data type: float32
senders list data type: int32
receivers list data type: int32
Labels matrix data type: int32
# Functions must be passed to jraph GNNs, but pytype does not recognise
# linen Modules as callables, so here we wrap them in a function.
def make_embed_fn(latent_size):
def embed(inputs):
return nn.Dense(latent_size)(inputs)
return embed
def _attention_logit_fn(sender_attr: jnp.ndarray,
receiver_attr: jnp.ndarray,
edges: jnp.ndarray) -> jnp.ndarray:
del edges
x = jnp.concatenate((sender_attr, receiver_attr), axis=1)
return nn.Dense(1)(x)
class GCN(nn.Module):
"""Defines a GAT network using FLAX
Args:
graph: GraphsTuple the network processes.
Returns:
output graph with updated node values.
"""
gcn1_output_dim: int
output_dim: int
@nn.compact
def __call__(self, x):
gcn1 = jraph.GraphConvolution(update_node_fn=lambda n: jax.nn.relu(make_embed_fn(self.gcn1_output_dim)(n)),
add_self_edges=True)
gcn2 = jraph.GraphConvolution(update_node_fn=nn.Dense(self.output_dim))
return gcn2(gcn1(x))
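Two details worth flagging, based on my reading of the jraph source (treat both as assumptions rather than documented guarantees): GraphConvolution applies update_node_fn before the normalized aggregation, and only the first layer here adds self edges. The forward pass is therefore roughly

Z = \hat{A}_2\,\hat{A}_1\,\mathrm{ReLU}\big(X W^{(0)}\big)\,W^{(1)},\qquad \hat{A}_1 = \tilde{D}^{-1/2}(A+I)\tilde{D}^{-1/2},\quad \hat{A}_2 = D^{-1/2}A\,D^{-1/2}

which is close to, but not identical to, the Kipf & Welling form above: the ReLU is applied before the first aggregation, the second layer has no self-loops, and the final softmax lives in the loss function rather than the model.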
model = GCN(8, len(jnp.unique(graph_labels)))
model
GCN(
# attributes
gcn1_output_dim = 8
output_dim = 7
)
We set the optimizer to Adam using the optax library, then initialize the model with random parameters.
optimizer = optax.adam(learning_rate=0.01)
rng, inp_rng, init_rng = jax.random.split(jax.random.PRNGKey(random_seed), 3)
params = model.init(jax.random.PRNGKey(random_seed), graph)
model_state = train_state.TrainState.create(apply_fn=model.apply,
params=params,
tx=optimizer)
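To see what model.init produced, a quick sketch that inspects the parameter pytree (tree_map over shapes is a standard JAX idiom):

# Print the shape of every parameter array in the pytree.
print(jax.tree_util.tree_map(jnp.shape, params))
# For Cora this should show a (1433, 8) kernel for the first Dense layer
# and an (8, 7) kernel for the second, plus their bias vectors.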
def compute_loss(state, params, graph, labels, one_hot_labels, mask):
"""Computes loss."""
pred_graph = state.apply_fn(params, graph)
preds = jax.nn.log_softmax(pred_graph.nodes)
loss = optax.softmax_cross_entropy(preds, one_hot_labels)
loss_mask = jnp.sum(jnp.where(mask, loss, 0)) / jnp.sum(mask)
pred_labels = jnp.argmax(preds, axis=1)
acc = (pred_labels == labels)
acc_mask = jnp.sum(jnp.where(mask, acc, 0)) / jnp.sum(mask)
return loss_mask, acc_mask
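One subtlety: optax.softmax_cross_entropy expects raw logits, while compute_loss feeds it log-probabilities. The loss is still correct, because log_softmax only shifts each row by its logsumexp and softmax is shift-invariant; a quick numerical check (sketch):

logits = jnp.array([[2.0, 0.5, -1.0]])
targets = jax.nn.one_hot(jnp.array([0]), 3)
print(optax.softmax_cross_entropy(logits, targets))
print(optax.softmax_cross_entropy(jax.nn.log_softmax(logits), targets))  # same value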
@jax.jit # Jit the function for efficiency
def train_step(state, graph, graph_labels, one_hot_labels, train_mask):
# Gradient function
grad_fn = jax.value_and_grad(compute_loss, # Function to calculate the loss
argnums=1, # Parameters are second argument of the function
has_aux=True # Function has additional outputs, here accuracy
)
# Determine gradients for current model, parameters and batch
(loss, acc), grads = grad_fn(state, state.params, graph, graph_labels, one_hot_labels, train_mask)
# Perform parameter update with gradients and optimizer
state = state.apply_gradients(grads=grads)
# Return state and any other value we might want
return state, loss, acc
def train_model(state, graph, graph_labels, one_hot_labels, train_mask, val_mask, num_epochs):
# Training loop
for epoch in range(num_epochs):
state, loss, acc = train_step(state, graph, graph_labels, one_hot_labels, train_mask)
val_loss, val_acc = compute_loss(state, state.params, graph, graph_labels, one_hot_labels, val_mask)
print(f'step: {epoch:03d}, train loss: {loss:.4f}, train acc: {acc:.4f}, val loss: {val_loss:.4f}, val acc: {val_acc:.4f}')
return state, acc, val_acc
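A minor efficiency note: the validation call inside the loop runs un-jitted every epoch. Flax's TrainState is a pytree (its apply_fn field is static), so compute_loss can be jitted as-is; a sketch, assuming the same call signature:

eval_step = jax.jit(compute_loss)
# Inside the loop, the validation line then becomes:
# val_loss, val_acc = eval_step(state, state.params, graph, graph_labels, one_hot_labels, val_mask)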
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
trained_model_state, train_acc, val_acc = train_model(model_state, graph, graph_labels, one_hot_labels, graph_train_mask, graph_val_mask, num_epochs=200)
accuracy_list.append(['Cora', 'train', float(train_acc)])
accuracy_list.append(['Cora', 'valid', float(val_acc)])
step: 000, train loss: 1.9462, train acc: 0.1209, val loss: 1.9386, val acc: 0.3100
step: 001, train loss: 1.9387, train acc: 0.2972, val loss: 1.9307, val acc: 0.3600
step: 002, train loss: 1.9309, train acc: 0.3295, val loss: 1.9215, val acc: 0.3720
step: 003, train loss: 1.9218, train acc: 0.3659, val loss: 1.9113, val acc: 0.3760
step: 004, train loss: 1.9118, train acc: 0.3651, val loss: 1.9013, val acc: 0.3600
step: 005, train loss: 1.9020, train acc: 0.3336, val loss: 1.8914, val acc: 0.3280
step: 006, train loss: 1.8924, train acc: 0.3030, val loss: 1.8811, val acc: 0.3180
step: 007, train loss: 1.8824, train acc: 0.2914, val loss: 1.8707, val acc: 0.3160
step: 008, train loss: 1.8723, train acc: 0.2881, val loss: 1.8601, val acc: 0.3160
step: 009, train loss: 1.8620, train acc: 0.2856, val loss: 1.8494, val acc: 0.3160
step: 010, train loss: 1.8516, train acc: 0.2848, val loss: 1.8388, val acc: 0.3160
...
...
...
step: 191, train loss: 0.2367, train acc: 0.9487, val loss: 0.4670, val acc: 0.8560
step: 192, train loss: 0.2349, train acc: 0.9495, val loss: 0.4665, val acc: 0.8540
step: 193, train loss: 0.2332, train acc: 0.9503, val loss: 0.4660, val acc: 0.8540
step: 194, train loss: 0.2315, train acc: 0.9503, val loss: 0.4655, val acc: 0.8540
step: 195, train loss: 0.2298, train acc: 0.9512, val loss: 0.4650, val acc: 0.8540
step: 196, train loss: 0.2281, train acc: 0.9512, val loss: 0.4646, val acc: 0.8540
step: 197, train loss: 0.2264, train acc: 0.9520, val loss: 0.4641, val acc: 0.8540
step: 198, train loss: 0.2248, train acc: 0.9520, val loss: 0.4637, val acc: 0.8540
step: 199, train loss: 0.2232, train acc: 0.9520, val loss: 0.4633, val acc: 0.8540
test_loss, test_acc = compute_loss(trained_model_state, trained_model_state.params, graph, graph_labels, one_hot_labels, graph_test_mask)
print(f'test loss: {test_loss:.4f}, test acc: {test_acc:.4f}')
accuracy_list.append(['Cora', 'test', float(test_acc)])
test loss: 0.4664, test acc: 0.8550
nodes_untrained = model_state.apply_fn(model_state.params, graph).nodes
z_nodes_untrained = TSNE(n_components=2).fit_transform(nodes_untrained)
nodes_trained = trained_model_state.apply_fn(trained_model_state.params, graph).nodes
z_nodes_trained = TSNE(n_components=2).fit_transform(nodes_trained)
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
axs[0].set_title('Before training')
axs[0].scatter(z_nodes_untrained[:,0], z_nodes_untrained[:,1], marker='o', color=np.array(plot_colors)[graph_labels])
axs[1].set_title('After training')
axs[1].scatter(z_nodes_trained[:,0], z_nodes_trained[:,1], marker='o', color=np.array(plot_colors)[graph_labels])
for ax in axs:
ax.tick_params(axis='both',which='both',bottom=False,top=False,left=False,right=False,
labelbottom=False,labeltop=False,labelleft=False,labelright=False);
ax.set(xlabel=None, ylabel=None)
plt.show()
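scikit-learn's TSNE accepts the JAX arrays above because they expose the NumPy array interface; if your scikit-learn version complains, converting explicitly is a safe fallback (sketch):

z_nodes_trained = TSNE(n_components=2).fit_transform(np.asarray(nodes_trained))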
As before, you can find a description of this dataset in the PyTorch Geometric documentation. With split='full', all nodes except those in the validation and test sets are used for training.
dataset = Planetoid(root='data/Planetoid', name='CiteSeer', split='full', transform=NormalizeFeatures())
num_features = dataset.num_features
num_classes = dataset.num_classes
data_CiteSeer = dataset[0] # Get the first graph object.
data_CiteSeer
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.test.index
Processing...
Done!
Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])
graph = jraph.GraphsTuple(
n_node=jnp.asarray([data_CiteSeer.x.shape[0]]),
n_edge=jnp.asarray([data_CiteSeer.edge_index.shape[1]]),
nodes=jnp.asarray(data_CiteSeer.x),
# No edge features.
edges=None,
globals=None,
senders=jnp.asarray([data_CiteSeer.edge_index[0,:]]).squeeze(),
receivers=jnp.asarray([data_CiteSeer.edge_index[1,:]]).squeeze())
graph_train_mask = jnp.asarray([data_CiteSeer.train_mask]).squeeze()
graph_val_mask = jnp.asarray([data_CiteSeer.val_mask]).squeeze()
graph_test_mask = jnp.asarray([data_CiteSeer.test_mask]).squeeze()
graph_labels = jnp.asarray([data_CiteSeer.y]).squeeze()
one_hot_labels = jax.nn.one_hot(graph_labels, len(jnp.unique(graph_labels)))
model = GCN(8, len(jnp.unique(graph_labels)))
model
GCN(
# attributes
gcn1_output_dim = 8
output_dim = 6
)
optimizer = optax.adam(learning_rate=0.01)
rng, inp_rng, init_rng = jax.random.split(jax.random.PRNGKey(random_seed), 3)
params = model.init(jax.random.PRNGKey(random_seed), graph)
model_state = train_state.TrainState.create(apply_fn=model.apply,
params=params,
tx=optimizer)
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
trained_model_state, train_acc, val_acc = train_model(model_state, graph, graph_labels, one_hot_labels, graph_train_mask, graph_val_mask, num_epochs=200)
accuracy_list.append(['CiteSeer', 'train', float(train_acc)])
accuracy_list.append(['CiteSeer', 'valid', float(val_acc)])
step: 000, train loss: 1.7918, train acc: 0.1932, val loss: 1.7849, val acc: 0.2080
step: 001, train loss: 1.7857, train acc: 0.1970, val loss: 1.7781, val acc: 0.2080
step: 002, train loss: 1.7796, train acc: 0.1970, val loss: 1.7709, val acc: 0.2080
step: 003, train loss: 1.7730, train acc: 0.1970, val loss: 1.7628, val acc: 0.2100
step: 004, train loss: 1.7654, train acc: 0.1987, val loss: 1.7548, val acc: 0.2160
step: 005, train loss: 1.7579, train acc: 0.2053, val loss: 1.7472, val acc: 0.2200
step: 006, train loss: 1.7507, train acc: 0.2080, val loss: 1.7400, val acc: 0.2180
step: 007, train loss: 1.7439, train acc: 0.2063, val loss: 1.7328, val acc: 0.2180
step: 008, train loss: 1.7370, train acc: 0.2063, val loss: 1.7255, val acc: 0.2180
step: 009, train loss: 1.7300, train acc: 0.2091, val loss: 1.7182, val acc: 0.2240
step: 010, train loss: 1.7230, train acc: 0.2195, val loss: 1.7109, val acc: 0.2420
...
...
...
step: 191, train loss: 0.3845, train acc: 0.8659, val loss: 0.7996, val acc: 0.7560
step: 192, train loss: 0.3827, train acc: 0.8664, val loss: 0.8007, val acc: 0.7560
step: 193, train loss: 0.3809, train acc: 0.8670, val loss: 0.8018, val acc: 0.7540
step: 194, train loss: 0.3791, train acc: 0.8681, val loss: 0.8029, val acc: 0.7540
step: 195, train loss: 0.3773, train acc: 0.8686, val loss: 0.8040, val acc: 0.7540
step: 196, train loss: 0.3756, train acc: 0.8686, val loss: 0.8051, val acc: 0.7560
step: 197, train loss: 0.3738, train acc: 0.8697, val loss: 0.8063, val acc: 0.7560
step: 198, train loss: 0.3721, train acc: 0.8714, val loss: 0.8074, val acc: 0.7580
step: 199, train loss: 0.3704, train acc: 0.8719, val loss: 0.8086, val acc: 0.7580
test_loss, test_acc = compute_loss(trained_model_state, trained_model_state.params, graph, graph_labels, one_hot_labels, graph_test_mask)
print(f'test loss: {test_loss:.4f}, test acc: {test_acc:.4f}')
accuracy_list.append(['CiteSeer', 'test', float(test_acc)])
test loss: 0.7778, test acc: 0.7680
nodes_untrained = model_state.apply_fn(model_state.params, graph).nodes
z_nodes_untrained = TSNE(n_components=2).fit_transform(nodes_untrained)
nodes_trained = trained_model_state.apply_fn(trained_model_state.params, graph).nodes
z_nodes_trained = TSNE(n_components=2).fit_transform(nodes_trained)
fig, axs = plt.subplots(1, 2, figsize=(12, 6))
axs[0].set_title('Before training')
axs[0].scatter(z_nodes_untrained[:,0], z_nodes_untrained[:,1], marker='o', color=np.array(plot_colors)[graph_labels])
axs[1].set_title('After training')
axs[1].scatter(z_nodes_trained[:,0], z_nodes_trained[:,1], marker='o', color=np.array(plot_colors)[graph_labels])
for ax in axs:
ax.tick_params(axis='both',which='both',bottom=False,top=False,left=False,right=False,
labelbottom=False,labeltop=False,labelleft=False,labelright=False);
ax.set(xlabel=None, ylabel=None)
plt.show()
df = pd.DataFrame(accuracy_list, columns=('Dataset', 'Split', 'Accuracy'))
sns.barplot(df,x='Dataset', y='Accuracy', hue='Split', palette="muted")
plt.show()
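The same numbers can be read off as a table with a standard pandas pivot (sketch):

# Rows: datasets; columns: splits; values: accuracy.
print(df.pivot(index='Dataset', columns='Split', values='Accuracy').round(4))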