I borrowed some code from Introduction to Graph Neural Nets with JAX/jraph and Jraph’s train_flax.py.
@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {JAX: composable transformations of Python+NumPy programs},
  url = {http://github.com/google/jax},
  version = {0.3.13},
  year = {2018},
}
@article{velickovic2018graph,
  title = {{Graph Attention Networks}},
  author = {Petar Veličković and Guillem Cucurull and Arantxa Casanova and Adriana Romero and Pietro Liò and Yoshua Bengio},
  journal = {International Conference on Learning Representations},
  year = {2018},
}
# install pytorch_geometric
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
# install jraph and jax dependencies
!pip install git+https://github.com/deepmind/jraph.git
# Standard libraries
import numpy as np
from scipy import sparse
import seaborn as sns
import pandas as pd
# Plotting libraries
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib import cm
from IPython.display import Javascript # Restrict height of output cell.
# PyTorch
import torch
import torch.nn.functional as F
from torch.nn import Linear
# PyTorch geometric
from torch_geometric.utils import from_scipy_sparse_matrix
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import (Planetoid, KarateClub)
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.data import Data
from torch_geometric import seed_everything
# jax
import jax
import jax.numpy as jnp
import jax.tree_util as tree
import flax
import flax.linen as nn
from flax.training import train_state
import optax
import pickle
# jraph
import jraph
from jraph._src import models as jraph_models
random_seed = 42
plt.style.use('dark_background')
accuracy_list = []
You can find a description of this dataset in the PyTorch Geometric documentation. With split='full', all nodes except those in the validation and test sets are used for training (a quick sanity check of this follows the loaded data below).
dataset = Planetoid(root='data/Planetoid', name='Cora', split='full', transform=NormalizeFeatures())
num_features = dataset.num_features
num_classes = dataset.num_classes
data_Cora = dataset[0] # Get the first graph object.
data_Cora
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
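Before converting to a jraph graph, here is a quick sanity check of the split='full' behaviour described above (a sketch that simply restates the documented semantics; the exact node counts depend on the dataset version):
# Every node outside the validation and test sets should be flagged for training.
assert bool(torch.all(data_Cora.train_mask == ~(data_Cora.val_mask | data_Cora.test_mask)))
print(f'train/val/test nodes: {int(data_Cora.train_mask.sum())}, '
      f'{int(data_Cora.val_mask.sum())}, {int(data_Cora.test_mask.sum())}')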
graph = jraph.GraphsTuple(
    n_node=jnp.asarray([data_Cora.x.shape[0]]),
    n_edge=jnp.asarray([data_Cora.edge_index.shape[1]]),
    nodes=jnp.asarray(data_Cora.x),
    # No edge features.
    edges=None,
    globals=None,
    senders=jnp.asarray(data_Cora.edge_index[0, :]),
    receivers=jnp.asarray(data_Cora.edge_index[1, :]))
graph_train_mask = jnp.asarray(data_Cora.train_mask)
graph_val_mask = jnp.asarray(data_Cora.val_mask)
graph_test_mask = jnp.asarray(data_Cora.test_mask)
graph_labels = jnp.asarray(data_Cora.y)
one_hot_labels = jax.nn.one_hot(graph_labels, len(jnp.unique(graph_labels)))
Let's check that we have the right number of nodes and edges. We also want to confirm that the data types were converted correctly, e.g. float32 instead of torch.float32.
print(f'Number of nodes: {graph.n_node[0]}')
print(f'Number of edges: {graph.n_edge[0]}')
print(f'Feature matrix data type: {graph.nodes.dtype}')
print(f'senders list data type: {graph.senders.dtype}')
print(f'receivers list data type: {graph.receivers.dtype}')
print(f'Labels matrix data type: {graph_labels.dtype}')
Number of nodes: 2708
Number of edges: 10556
Feature matrix data type: float32
senders list data type: int32
receivers list data type: int32
Labels matrix data type: int32
# Functions must be passed to jraph GNNs, but pytype does not recognise
# linen Modules as callables, so here we wrap them in a function.
def make_embed_fn(latent_size):
    def embed(inputs):
        return nn.Dense(latent_size)(inputs)
    return embed

def _attention_logit_fn(sender_attr: jnp.ndarray,
                        receiver_attr: jnp.ndarray,
                        edges: jnp.ndarray) -> jnp.ndarray:
    # Edge features are unused: the attention logits depend only on the
    # concatenated sender and receiver node features.
    del edges
    x = jnp.concatenate((sender_attr, receiver_attr), axis=1)
    return nn.Dense(1)(x)
class GAT(nn.Module):
    """Defines a two-layer GAT network using Flax.

    Args:
        x: GraphsTuple the network processes.

    Returns:
        Output graph with updated node values.
    """
    gat1_output_dim: int
    gat2_output_dim: int
    output_dim: int

    @nn.compact
    def __call__(self, x):
        gat1 = jraph.GAT(attention_query_fn=lambda n: make_embed_fn(self.gat1_output_dim)(n),
                         attention_logit_fn=_attention_logit_fn,
                         node_update_fn=None)
        gat2 = jraph.GAT(attention_query_fn=lambda n: make_embed_fn(self.gat2_output_dim)(n),
                         attention_logit_fn=_attention_logit_fn,
                         node_update_fn=nn.Dense(self.output_dim))
        return gat2(gat1(x))
You might ask: where did we apply the nonlinearity? By setting node_update_fn=None we let jraph.GAT fall back to its default node update function, which is jax.nn.leaky_relu, as the snippet below shows.
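Paraphrased from jraph's GAT implementation (not copied verbatim; the exact code may differ between versions):
# If no node_update_fn is supplied, jraph falls back to a leaky ReLU
# applied to the updated node features.
if node_update_fn is None:
    node_update_fn = lambda x: jax.nn.leaky_relu(x)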
model = GAT(8, 8, len(jnp.unique(graph_labels)))
model
GAT(
# attributes
gat1_output_dim = 8
gat2_output_dim = 8
output_dim = 7
)
We set the optimizer to Adam using the optax library, then initialize the model's parameters randomly; a quick way to inspect what gets created is sketched after the following cell.
optimizer = optax.adam(learning_rate=0.01)
rng, inp_rng, init_rng = jax.random.split(jax.random.PRNGKey(random_seed), 3)
params = model.init(jax.random.PRNGKey(random_seed),graph)
model_state = train_state.TrainState.create(apply_fn=model.apply,
params=params,
tx=optimizer)
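To see what model.init produced, one can map over the parameter pytree (a sketch; the exact submodule names, such as 'Dense_0', depend on how Flax auto-names the layers):
# Print the shape of every parameter array in the pytree.
param_shapes = jax.tree_util.tree_map(lambda p: p.shape, params)
print(param_shapes)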
For the loss function, we use cross entropy inside the compute_loss function to evaluate the model predictions. The problem arises when we try to mask the model output for each of the train/val/test splits separately: we cannot use Boolean indexing under @jax.jit, as it throws a NonConcreteBooleanIndexError (a minimal sketch of the failing pattern is shown below). You can read more about this error in the JAX documentation on jax.errors.NonConcreteBooleanIndexError. JAX's jit compilation requires that all array shapes be known at compile time, which is impossible with Boolean indexing because the output shape depends on the runtime values of the mask.
So we have to get back to basics!
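A minimal sketch of the failing pattern (a hypothetical stand-in for the original screenshot):
@jax.jit
def masked_mean_loss(loss, mask):
    # Boolean indexing yields an array whose shape depends on the runtime
    # values in `mask`; jit cannot know that shape at trace time, so this
    # raises jax.errors.NonConcreteBooleanIndexError.
    return loss[mask].mean()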
To avoid the error above, JAX recommends using jnp.where. We pass the mask to the compute_loss function and compute the per-node losses without masking. We then use jnp.where on the loss array to set the masked-out nodes to zero and sum up the rest, which gives us the sum needed to compute the mean. For the count, we sum up the mask entries; since the mask contains only ones and zeros, this returns the number of nodes in the mask. Dividing the sum by the count gives the mean loss over the nodes in the mask.
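A tiny illustration of this masked-mean trick with made-up numbers:
loss = jnp.array([0.5, 1.0, 2.0, 4.0])
mask = jnp.array([True, False, True, False])
# Zero out the masked-out entries, sum, and divide by the number of masked-in nodes.
masked_mean = jnp.sum(jnp.where(mask, loss, 0)) / jnp.sum(mask)  # (0.5 + 2.0) / 2 = 1.25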
def compute_loss(state, params, graph, labels, one_hot_labels, mask):
    """Computes the mean loss and accuracy over the nodes selected by `mask`."""
    pred_graph = state.apply_fn(params, graph)
    # optax.softmax_cross_entropy expects unnormalized logits, so we pass the
    # raw node outputs rather than log-probabilities.
    logits = pred_graph.nodes
    loss = optax.softmax_cross_entropy(logits, one_hot_labels)
    loss_mask = jnp.sum(jnp.where(mask, loss, 0)) / jnp.sum(mask)
    pred_labels = jnp.argmax(logits, axis=1)
    acc = (pred_labels == labels)
    acc_mask = jnp.sum(jnp.where(mask, acc, 0)) / jnp.sum(mask)
    return loss_mask, acc_mask
@jax.jit  # Jit the function for efficiency
def train_step(state, graph, graph_labels, one_hot_labels, train_mask):
    # Gradient function
    grad_fn = jax.value_and_grad(compute_loss,  # Function to calculate the loss
                                 argnums=1,     # Parameters are the second argument of the function
                                 has_aux=True)  # Function has additional outputs, here accuracy
    # Determine gradients for the current model, parameters and batch
    (loss, acc), grads = grad_fn(state, state.params, graph, graph_labels, one_hot_labels, train_mask)
    # Perform parameter update with gradients and optimizer
    state = state.apply_gradients(grads=grads)
    # Return state and any other value we might want
    return state, loss, acc
def train_model(state, graph, graph_labels, one_hot_labels, train_mask, val_mask, num_epochs):
    # Training loop
    for epoch in range(num_epochs):
        state, loss, acc = train_step(state, graph, graph_labels, one_hot_labels, train_mask)
        val_loss, val_acc = compute_loss(state, state.params, graph, graph_labels, one_hot_labels, val_mask)
        print(f'step: {epoch:03d}, train loss: {loss:.4f}, train acc: {acc:.4f}, val loss: {val_loss:.4f}, val acc: {val_acc:.4f}')
    return state, acc, val_acc
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
trained_model_state, train_acc, val_acc = train_model(model_state, graph, graph_labels, one_hot_labels, graph_train_mask, graph_val_mask, num_epochs=100)
accuracy_list.append(['Cora', 'train', float(train_acc)])
accuracy_list.append(['Cora', 'valid', float(val_acc)])
step: 000, train loss: 1.9455, train acc: 0.2012, val loss: 1.9298, val acc: 0.3140
step: 001, train loss: 1.9314, train acc: 0.2839, val loss: 1.9131, val acc: 0.3160
step: 002, train loss: 1.9167, train acc: 0.2823, val loss: 1.8954, val acc: 0.3160
step: 003, train loss: 1.9012, train acc: 0.2823, val loss: 1.8772, val acc: 0.3160
step: 004, train loss: 1.8854, train acc: 0.2823, val loss: 1.8588, val acc: 0.3160
step: 005, train loss: 1.8696, train acc: 0.2823, val loss: 1.8407, val acc: 0.3160
step: 006, train loss: 1.8542, train acc: 0.2823, val loss: 1.8232, val acc: 0.3160
step: 007, train loss: 1.8394, train acc: 0.2823, val loss: 1.8067, val acc: 0.3160
step: 008, train loss: 1.8256, train acc: 0.2823, val loss: 1.7916, val acc: 0.3160
step: 009, train loss: 1.8131, train acc: 0.2823, val loss: 1.7780, val acc: 0.3160
step: 010, train loss: 1.8019, train acc: 0.2823, val loss: 1.7656, val acc: 0.3160
...
...
...
step: 091, train loss: 0.2185, train acc: 0.9421, val loss: 0.5830, val acc: 0.8520
step: 092, train loss: 0.2118, train acc: 0.9454, val loss: 0.5841, val acc: 0.8500
step: 093, train loss: 0.2053, train acc: 0.9478, val loss: 0.5855, val acc: 0.8500
step: 094, train loss: 0.1990, train acc: 0.9487, val loss: 0.5871, val acc: 0.8500
step: 095, train loss: 0.1930, train acc: 0.9487, val loss: 0.5889, val acc: 0.8500
step: 096, train loss: 0.1871, train acc: 0.9503, val loss: 0.5911, val acc: 0.8480
step: 097, train loss: 0.1814, train acc: 0.9520, val loss: 0.5934, val acc: 0.8500
step: 098, train loss: 0.1758, train acc: 0.9553, val loss: 0.5960, val acc: 0.8480
step: 099, train loss: 0.1705, train acc: 0.9561, val loss: 0.5988, val acc: 0.8460
test_loss, test_acc = compute_loss(trained_model_state, trained_model_state.params, graph, graph_labels, one_hot_labels, graph_test_mask)
print(f'test loss: {test_loss:.4f}, test acc: {test_acc:.4f}')
accuracy_list.append(['Cora', 'test', float(test_acc)])
test loss: 0.6685, test acc: 0.8170
As with Cora, you can find a description of the CiteSeer dataset in the PyTorch Geometric documentation. With split='full', all nodes except those in the validation and test sets are used for training.
dataset = Planetoid(root='data/Planetoid', name='CiteSeer', split='full', transform=NormalizeFeatures())
num_features = dataset.num_features
num_classes = dataset.num_classes
data_CiteSeer = dataset[0] # Get the first graph object.
data_CiteSeer
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.test.index
Processing...
Done!
Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])
graph = jraph.GraphsTuple(
    n_node=jnp.asarray([data_CiteSeer.x.shape[0]]),
    n_edge=jnp.asarray([data_CiteSeer.edge_index.shape[1]]),
    nodes=jnp.asarray(data_CiteSeer.x),
    # No edge features.
    edges=None,
    globals=None,
    senders=jnp.asarray(data_CiteSeer.edge_index[0, :]),
    receivers=jnp.asarray(data_CiteSeer.edge_index[1, :]))
graph_train_mask = jnp.asarray(data_CiteSeer.train_mask)
graph_val_mask = jnp.asarray(data_CiteSeer.val_mask)
graph_test_mask = jnp.asarray(data_CiteSeer.test_mask)
graph_labels = jnp.asarray(data_CiteSeer.y)
one_hot_labels = jax.nn.one_hot(graph_labels, len(jnp.unique(graph_labels)))
model = GAT(8, 8, len(jnp.unique(graph_labels)))
model
GAT(
# attributes
gat1_output_dim = 8
gat2_output_dim = 8
output_dim = 6
)
optimizer = optax.adam(learning_rate=0.01)
rng, inp_rng, init_rng = jax.random.split(jax.random.PRNGKey(random_seed), 3)
params = model.init(jax.random.PRNGKey(random_seed),graph)
model_state = train_state.TrainState.create(apply_fn=model.apply,
params=params,
tx=optimizer)
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
trained_model_state, train_acc, val_acc = train_model(model_state, graph, graph_labels, one_hot_labels, graph_train_mask, graph_val_mask, num_epochs=100)
accuracy_list.append(['CiteSeer', 'train', float(train_acc)])
accuracy_list.append(['CiteSeer', 'valid', float(val_acc)])
step: 000, train loss: 1.7920, train acc: 0.1007, val loss: 1.7826, val acc: 0.2080
step: 001, train loss: 1.7842, train acc: 0.1959, val loss: 1.7733, val acc: 0.2080
step: 002, train loss: 1.7769, train acc: 0.1959, val loss: 1.7640, val acc: 0.2080
step: 003, train loss: 1.7695, train acc: 0.1959, val loss: 1.7547, val acc: 0.2080
step: 004, train loss: 1.7622, train acc: 0.1959, val loss: 1.7454, val acc: 0.2080
step: 005, train loss: 1.7551, train acc: 0.1959, val loss: 1.7361, val acc: 0.2080
step: 006, train loss: 1.7481, train acc: 0.1959, val loss: 1.7272, val acc: 0.2080
step: 007, train loss: 1.7413, train acc: 0.1959, val loss: 1.7184, val acc: 0.2080
step: 008, train loss: 1.7346, train acc: 0.1959, val loss: 1.7099, val acc: 0.2080
step: 009, train loss: 1.7278, train acc: 0.1959, val loss: 1.7015, val acc: 0.2100
step: 010, train loss: 1.7207, train acc: 0.1976, val loss: 1.6932, val acc: 0.2920
...
...
...
step: 091, train loss: 0.2316, train acc: 0.9217, val loss: 0.9846, val acc: 0.7100
step: 092, train loss: 0.2263, train acc: 0.9228, val loss: 0.9931, val acc: 0.7100
step: 093, train loss: 0.2211, train acc: 0.9234, val loss: 1.0016, val acc: 0.7080
step: 094, train loss: 0.2162, train acc: 0.9239, val loss: 1.0100, val acc: 0.7100
step: 095, train loss: 0.2114, train acc: 0.9261, val loss: 1.0182, val acc: 0.7040
step: 096, train loss: 0.2069, train acc: 0.9267, val loss: 1.0264, val acc: 0.7040
step: 097, train loss: 0.2024, train acc: 0.9288, val loss: 1.0347, val acc: 0.7020
step: 098, train loss: 0.1982, train acc: 0.9278, val loss: 1.0435, val acc: 0.7000
step: 099, train loss: 0.1941, train acc: 0.9283, val loss: 1.0527, val acc: 0.7000
test_loss, test_acc = compute_loss(trained_model_state, trained_model_state.params, graph, graph_labels, one_hot_labels, graph_test_mask)
print(f'test loss: {test_loss:.4f}, test acc: {test_acc:.4f}')
accuracy_list.append(['CiteSeer', 'test', float(test_acc)])
test loss: 1.0631, test acc: 0.6960
df = pd.DataFrame(accuracy_list, columns=('Dataset', 'Split', 'Accuracy'))
sns.barplot(data=df, x='Dataset', y='Accuracy', hue='Split', palette='muted')
plt.show()