Mashaan blog

Graph Attention Networks in JAX

Acknowledgment:

I borrowed some code from Introduction to Graph Neural Nets with JAX/jraph and Jraph’s train_flax.py.

References:

@software{jax2018github,
 author   = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
 title    = {JAX: composable transformations of Python+NumPy programs},
 url      = {http://github.com/google/jax},
 version  = {0.3.13},
 year     = {2018},
}
@article{velickovic2018graph,
 title    ="{Graph Attention Networks}",
author    ={Petar Veličković and Guillem Cucurull and Arantxa Casanova and Adriana Romero and Pietro Liò and Yoshua Bengio},
journal   ={International Conference on Learning Representations},
year      ={2018},
}

Import libraries

# install pytorch_geometric
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

# install jraph and jax dependencies
!pip install git+https://github.com/deepmind/jraph.git
# Standard libraries
import numpy as np
from scipy import sparse
import seaborn as sns
import pandas as pd

# Plotting libraries
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib import cm
from IPython.display import Javascript  # Restrict height of output cell.

# PyTorch
import torch
import torch.nn.functional as F
from torch.nn import Linear

# PyTorch geometric
from torch_geometric.utils import from_scipy_sparse_matrix
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import (Planetoid, KarateClub)
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.data import Data
from torch_geometric import seed_everything

# jax
import jax
import jax.numpy as jnp
import jax.tree_util as tree
import flax
import flax.linen as nn
from flax.training import train_state
import optax
import pickle

# jraph
import jraph
from jraph._src import models as jraph_models
# Seed from which the PRNG keys in this notebook are created.
random_seed = 42
# Rows of [dataset, split, accuracy]; turned into a DataFrame for the final plot.
accuracy_list = []
plt.style.use('dark_background')

Import Cora Dataset

You can find a description of this dataset in Pytorch-Geometric documentation. In case of split='full', all nodes except those in the validation and test sets will be used for training.

# Load Cora with the 'full' split; NormalizeFeatures is applied as the transform.
dataset = Planetoid(
    root='data/Planetoid',
    name='Cora',
    split='full',
    transform=NormalizeFeatures(),
)
num_features, num_classes = dataset.num_features, dataset.num_classes
data_Cora = dataset[0]  # Get the first graph object.
data_Cora
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
# Repackage the PyTorch Geometric data as a single-graph jraph.GraphsTuple.
graph = jraph.GraphsTuple(
      n_node=jnp.asarray([data_Cora.x.shape[0]]),
      n_edge=jnp.asarray([data_Cora.edge_index.shape[1]]),
      nodes=jnp.asarray(data_Cora.x),
      # No edge features.
      edges=None,
      globals=None,
      # Row 0 of edge_index is used as the senders, row 1 as the receivers.
      # The tensors are converted directly (as done for `nodes`): wrapping them
      # in a list and squeezing afterwards is unnecessary and would collapse a
      # length-1 array to a 0-d one.
      senders=jnp.asarray(data_Cora.edge_index[0]),
      receivers=jnp.asarray(data_Cora.edge_index[1]))

# Per-node split masks and labels, converted the same way as the node features.
graph_train_mask = jnp.asarray(data_Cora.train_mask)
graph_val_mask = jnp.asarray(data_Cora.val_mask)
graph_test_mask = jnp.asarray(data_Cora.test_mask)
graph_labels = jnp.asarray(data_Cora.y)
# Use the class count reported by the dataset instead of recounting unique labels.
one_hot_labels = jax.nn.one_hot(graph_labels, num_classes)

Let’s check that we have the right number of nodes and edges. We also want to confirm that the arrays now have JAX data types (e.g. float32) rather than PyTorch ones (e.g. torch.float32).

# Sanity-check the converted graph: node/edge counts first, then the dtype of
# the feature matrix, the two edge index arrays and the labels.
print(f'Number of nodes: {graph.n_node[0]}')
print(f'Number of edges: {graph.n_edge[0]}')
print(f'Feature matrix data type: {graph.nodes.dtype}')
print(f'senders list data type:   {graph.senders.dtype}')
print(f'receivers list data type: {graph.receivers.dtype}')
print(f'Labels matrix data type:  {graph_labels.dtype}')
Number of nodes: 2708
Number of edges: 10556
Feature matrix data type: float32
senders list data type:   int32
receivers list data type: int32
Labels matrix data type:  int32

GAT Layers from Jraph

# jraph GNN layers take plain functions as arguments, but pytype does not
# recognise linen Modules as callables, so here we wrap them in a function.
def make_embed_fn(latent_size):
  """Returns a function that applies a new `nn.Dense(latent_size)` to its input."""
  def embed(inputs):
    dense = nn.Dense(features=latent_size)
    return dense(inputs)
  return embed

def _attention_logit_fn(sender_attr: jnp.ndarray,
                        receiver_attr: jnp.ndarray,
                        edges: jnp.ndarray) -> jnp.ndarray:
  """Computes one attention logit per row from sender and receiver attributes.

  The two attribute arrays are joined along axis 1 and passed through a
  `nn.Dense(1)` layer. `edges` is accepted but not used.
  """
  del edges  # Unused.
  paired_attr = jnp.concatenate([sender_attr, receiver_attr], axis=1)
  logit_layer = nn.Dense(features=1)
  return logit_layer(paired_attr)
class GAT(nn.Module):
  """Two stacked `jraph.GAT` layers wrapped in a Flax module.

  Attributes:
    gat1_output_dim: size passed to `make_embed_fn` for the first layer's
      attention query function.
    gat2_output_dim: size passed to `make_embed_fn` for the second layer's
      attention query function.
    output_dim: number of features of the `nn.Dense` node update used by the
      second layer.

  Calling the module with a graph `x` returns the result of applying the
  first layer and then the second layer to it.
  """
  gat1_output_dim: int
  gat2_output_dim: int
  output_dim: int

  @nn.compact
  def __call__(self, x):
    # First layer: node_update_fn=None leaves jraph.GAT's default in place.
    first_layer = jraph.GAT(
        attention_query_fn=make_embed_fn(self.gat1_output_dim),
        attention_logit_fn=_attention_logit_fn,
        node_update_fn=None)
    # Second layer: a Dense layer maps the node values to `output_dim` features.
    # NOTE(review): this Dense is created before the first layer runs, as in
    # the original; keep that order in case parameter naming depends on it.
    second_layer = jraph.GAT(
        attention_query_fn=make_embed_fn(self.gat2_output_dim),
        attention_logit_fn=_attention_logit_fn,
        node_update_fn=nn.Dense(self.output_dim))
    hidden_graph = first_layer(x)
    return second_layer(hidden_graph)

You might ask: where did we apply the nonlinearity?

By setting node_update_fn=None we let jraph.GAT apply its default nonlinear function, which is jax.nn.leaky_relu(x), as you can see from this code snippet taken from the source code:

image

# Both GAT layers use size 8; the output size is the number of distinct labels.
model = GAT(
    gat1_output_dim=8,
    gat2_output_dim=8,
    output_dim=len(jnp.unique(graph_labels)),
)
model
GAT(
    # attributes
    gat1_output_dim = 8
    gat2_output_dim = 8
    output_dim = 7
)

Optimizer and Loss

We set the optimizer to Adam using the optax library. Then we initialized the model with random parameters.

optimizer = optax.adam(learning_rate=0.01)

# NOTE(review): the three keys produced by the split are not read anywhere
# below; `model.init` is seeded with the un-split key built from `random_seed`
# rather than with `init_rng` - confirm whether that is intended.
root_key = jax.random.PRNGKey(random_seed)
rng, inp_rng, init_rng = jax.random.split(root_key, 3)
params = model.init(root_key, graph)

# Bundle the apply function, parameters and optimizer into one train state.
model_state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimizer,
)

The problem of masking under @jax.jit

For the loss function, we used cross entropy inside the compute_loss function to evaluate the model predictions. The problem arises when we try to mask the model output for each of the train/val/test splits separately. We cannot use Boolean indexing under @jax.jit; it throws NonConcreteBooleanIndexError, which is shown in the screenshot below. You can read about this error here. JAX’s jit compilation requires all array shapes to be known at compile time, and the shape of a Boolean-indexed result depends on the values in the mask, so it cannot be known in advance.

image

So we have to get back to the basics!!

To avoid the error above, JAX recommends using jnp.where. We pass the mask to the compute_loss function and compute the per-node losses without masking. We then use jnp.where on the loss array to set the entries outside the mask to zero and sum up the rest. This gives us the sum needed to compute the mean. For the count, we sum up the mask entries: because they are only ones and zeros, the sum is the number of nodes in the mask. Dividing the sum by the count gives us the mean loss for the nodes in the mask.

image

def compute_loss(state, params, graph, labels, one_hot_labels, mask):
  """Computes the mean cross-entropy loss and accuracy over the masked nodes.

  Args:
    state: object whose `apply_fn(params, graph)` returns a graph; the `nodes`
      of that graph are used as per-node class logits.
    params: parameters passed to `state.apply_fn`.
    graph: input graph passed to `state.apply_fn`.
    labels: class index of every node, compared against the argmax prediction.
    one_hot_labels: one-hot version of `labels`, used for the cross entropy.
    mask: array that is true for the nodes that should be counted.

  Returns:
    A `(loss, accuracy)` pair, each averaged over the nodes selected by `mask`.
  """
  logits = state.apply_fn(params, graph).nodes
  # optax.softmax_cross_entropy applies log-softmax to its first argument
  # itself, so the raw logits are passed in rather than normalising them twice.
  per_node_loss = optax.softmax_cross_entropy(logits, one_hot_labels)

  # Boolean indexing is not available under jit, so entries outside the mask
  # are zeroed with jnp.where and the sum is divided by the mask count.
  num_selected = jnp.sum(mask)
  loss_mask = jnp.sum(jnp.where(mask, per_node_loss, 0)) / num_selected

  # The argmax of the logits equals the argmax of their log-softmax.
  pred_labels = jnp.argmax(logits, axis=1)
  is_correct = pred_labels == labels
  acc_mask = jnp.sum(jnp.where(mask, is_correct, 0)) / num_selected
  return loss_mask, acc_mask

Training

@jax.jit  # Jit the function for efficiency
def train_step(state, graph, graph_labels, one_hot_labels, train_mask):
  """Runs one optimisation step on the nodes selected by `train_mask`.

  Returns:
    A `(state, loss, acc)` tuple: the state after applying the gradients, and
    the loss and accuracy computed with the parameters from before the update.
  """
  def loss_and_acc(params):
    # Only `params` varies here, so gradients are taken with respect to it.
    return compute_loss(state, params, graph, graph_labels, one_hot_labels, train_mask)

  # has_aux=True: compute_loss returns (loss, accuracy); only loss is differentiated.
  (loss, acc), grads = jax.value_and_grad(loss_and_acc, has_aux=True)(state.params)
  updated_state = state.apply_gradients(grads=grads)
  return updated_state, loss, acc
def train_model(state, graph, graph_labels, one_hot_labels, train_mask, val_mask, num_epochs):
  """Trains for `num_epochs` steps, printing train/validation metrics each step.

  Args:
    state: train state passed to `train_step` and updated every epoch.
    graph: input graph.
    graph_labels: class index of every node.
    one_hot_labels: one-hot version of `graph_labels`.
    train_mask: mask of the nodes used for the training step.
    val_mask: mask of the nodes used for the validation metrics.
    num_epochs: number of training steps to run; must be at least 1.

  Returns:
    A `(state, acc, val_acc)` tuple: the final state plus the training and
    validation accuracy from the last epoch.

  Raises:
    ValueError: if `num_epochs` is smaller than 1, because there would be no
      last-epoch accuracies to return.
  """
  if num_epochs < 1:
    raise ValueError(f'num_epochs must be at least 1, got {num_epochs}')
  # Training loop
  for epoch in range(num_epochs):
    state, loss, acc = train_step(state, graph, graph_labels, one_hot_labels, train_mask)
    # Validation metrics are computed with the parameters from after the update.
    val_loss, val_acc = compute_loss(state, state.params, graph, graph_labels, one_hot_labels, val_mask)
    print(f'step: {epoch:03d}, train loss: {loss:.4f}, train acc: {acc:.4f}, val loss: {val_loss:.4f}, val acc: {val_acc:.4f}')
  return state, acc, val_acc

image

# Cap the height of the output cell, then train on Cora for 100 epochs.
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
trained_model_state, train_acc, val_acc = train_model(
    model_state,
    graph,
    graph_labels,
    one_hot_labels,
    graph_train_mask,
    graph_val_mask,
    num_epochs=100,
)
# Record the last-epoch accuracies for the final plot.
accuracy_list.extend([
    ['Cora', 'train', float(train_acc)],
    ['Cora', 'valid', float(val_acc)],
])
step: 000, train loss: 1.9455, train acc: 0.2012, val loss: 1.9298, val acc: 0.3140
step: 001, train loss: 1.9314, train acc: 0.2839, val loss: 1.9131, val acc: 0.3160
step: 002, train loss: 1.9167, train acc: 0.2823, val loss: 1.8954, val acc: 0.3160
step: 003, train loss: 1.9012, train acc: 0.2823, val loss: 1.8772, val acc: 0.3160
step: 004, train loss: 1.8854, train acc: 0.2823, val loss: 1.8588, val acc: 0.3160
step: 005, train loss: 1.8696, train acc: 0.2823, val loss: 1.8407, val acc: 0.3160
step: 006, train loss: 1.8542, train acc: 0.2823, val loss: 1.8232, val acc: 0.3160
step: 007, train loss: 1.8394, train acc: 0.2823, val loss: 1.8067, val acc: 0.3160
step: 008, train loss: 1.8256, train acc: 0.2823, val loss: 1.7916, val acc: 0.3160
step: 009, train loss: 1.8131, train acc: 0.2823, val loss: 1.7780, val acc: 0.3160
step: 010, train loss: 1.8019, train acc: 0.2823, val loss: 1.7656, val acc: 0.3160
...
...
...
step: 091, train loss: 0.2185, train acc: 0.9421, val loss: 0.5830, val acc: 0.8520
step: 092, train loss: 0.2118, train acc: 0.9454, val loss: 0.5841, val acc: 0.8500
step: 093, train loss: 0.2053, train acc: 0.9478, val loss: 0.5855, val acc: 0.8500
step: 094, train loss: 0.1990, train acc: 0.9487, val loss: 0.5871, val acc: 0.8500
step: 095, train loss: 0.1930, train acc: 0.9487, val loss: 0.5889, val acc: 0.8500
step: 096, train loss: 0.1871, train acc: 0.9503, val loss: 0.5911, val acc: 0.8480
step: 097, train loss: 0.1814, train acc: 0.9520, val loss: 0.5934, val acc: 0.8500
step: 098, train loss: 0.1758, train acc: 0.9553, val loss: 0.5960, val acc: 0.8480
step: 099, train loss: 0.1705, train acc: 0.9561, val loss: 0.5988, val acc: 0.8460

Testing

# Evaluate the trained parameters on the nodes selected by the test mask.
test_loss, test_acc = compute_loss(
    trained_model_state,
    trained_model_state.params,
    graph,
    graph_labels,
    one_hot_labels,
    graph_test_mask,
)
print(f'test loss: {test_loss:.4f}, test acc: {test_acc:.4f}')
accuracy_list.append(['Cora', 'test', float(test_acc)])
test loss: 0.6685, test acc: 0.8170

Import CiteSeer Dataset

You can find a description of this dataset in Pytorch-Geometric documentation. In case of split='full', all nodes except those in the validation and test sets will be used for training.

# Load CiteSeer with the 'full' split; NormalizeFeatures is applied as the transform.
dataset = Planetoid(
    root='data/Planetoid',
    name='CiteSeer',
    split='full',
    transform=NormalizeFeatures(),
)
num_features, num_classes = dataset.num_features, dataset.num_classes
data_CiteSeer = dataset[0]  # Get the first graph object.
data_CiteSeer
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.test.index
Processing...
Done!
Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])
# Repackage the PyTorch Geometric data as a single-graph jraph.GraphsTuple.
graph = jraph.GraphsTuple(
      n_node=jnp.asarray([data_CiteSeer.x.shape[0]]),
      n_edge=jnp.asarray([data_CiteSeer.edge_index.shape[1]]),
      nodes=jnp.asarray(data_CiteSeer.x),
      # No edge features.
      edges=None,
      globals=None,
      # Row 0 of edge_index is used as the senders, row 1 as the receivers.
      # The tensors are converted directly (as done for `nodes`): wrapping them
      # in a list and squeezing afterwards is unnecessary and would collapse a
      # length-1 array to a 0-d one.
      senders=jnp.asarray(data_CiteSeer.edge_index[0]),
      receivers=jnp.asarray(data_CiteSeer.edge_index[1]))

# Per-node split masks and labels, converted the same way as the node features.
graph_train_mask = jnp.asarray(data_CiteSeer.train_mask)
graph_val_mask = jnp.asarray(data_CiteSeer.val_mask)
graph_test_mask = jnp.asarray(data_CiteSeer.test_mask)
graph_labels = jnp.asarray(data_CiteSeer.y)
# Use the class count reported by the dataset instead of recounting unique labels.
one_hot_labels = jax.nn.one_hot(graph_labels, num_classes)
model = GAT(8, 8, num_classes)
model
GAT(
    # attributes
    gat1_output_dim = 8
    gat2_output_dim = 8
    output_dim = 6
)
optimizer = optax.adam(learning_rate=0.01)

# NOTE(review): the three keys produced by the split are not read anywhere
# below; `model.init` is seeded with the un-split key built from `random_seed`
# rather than with `init_rng` - confirm whether that is intended.
root_key = jax.random.PRNGKey(random_seed)
rng, inp_rng, init_rng = jax.random.split(root_key, 3)
params = model.init(root_key, graph)

# Bundle the apply function, parameters and optimizer into one train state.
model_state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimizer,
)
# Cap the height of the output cell, then train on CiteSeer for 100 epochs.
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
trained_model_state, train_acc, val_acc = train_model(
    model_state,
    graph,
    graph_labels,
    one_hot_labels,
    graph_train_mask,
    graph_val_mask,
    num_epochs=100,
)
# Record the last-epoch accuracies for the final plot.
accuracy_list.extend([
    ['CiteSeer', 'train', float(train_acc)],
    ['CiteSeer', 'valid', float(val_acc)],
])
step: 000, train loss: 1.7920, train acc: 0.1007, val loss: 1.7826, val acc: 0.2080
step: 001, train loss: 1.7842, train acc: 0.1959, val loss: 1.7733, val acc: 0.2080
step: 002, train loss: 1.7769, train acc: 0.1959, val loss: 1.7640, val acc: 0.2080
step: 003, train loss: 1.7695, train acc: 0.1959, val loss: 1.7547, val acc: 0.2080
step: 004, train loss: 1.7622, train acc: 0.1959, val loss: 1.7454, val acc: 0.2080
step: 005, train loss: 1.7551, train acc: 0.1959, val loss: 1.7361, val acc: 0.2080
step: 006, train loss: 1.7481, train acc: 0.1959, val loss: 1.7272, val acc: 0.2080
step: 007, train loss: 1.7413, train acc: 0.1959, val loss: 1.7184, val acc: 0.2080
step: 008, train loss: 1.7346, train acc: 0.1959, val loss: 1.7099, val acc: 0.2080
step: 009, train loss: 1.7278, train acc: 0.1959, val loss: 1.7015, val acc: 0.2100
step: 010, train loss: 1.7207, train acc: 0.1976, val loss: 1.6932, val acc: 0.2920
...
...
...
step: 091, train loss: 0.2316, train acc: 0.9217, val loss: 0.9846, val acc: 0.7100
step: 092, train loss: 0.2263, train acc: 0.9228, val loss: 0.9931, val acc: 0.7100
step: 093, train loss: 0.2211, train acc: 0.9234, val loss: 1.0016, val acc: 0.7080
step: 094, train loss: 0.2162, train acc: 0.9239, val loss: 1.0100, val acc: 0.7100
step: 095, train loss: 0.2114, train acc: 0.9261, val loss: 1.0182, val acc: 0.7040
step: 096, train loss: 0.2069, train acc: 0.9267, val loss: 1.0264, val acc: 0.7040
step: 097, train loss: 0.2024, train acc: 0.9288, val loss: 1.0347, val acc: 0.7020
step: 098, train loss: 0.1982, train acc: 0.9278, val loss: 1.0435, val acc: 0.7000
step: 099, train loss: 0.1941, train acc: 0.9283, val loss: 1.0527, val acc: 0.7000
# Evaluate the trained parameters on the nodes selected by the test mask.
test_loss, test_acc = compute_loss(
    trained_model_state,
    trained_model_state.params,
    graph,
    graph_labels,
    one_hot_labels,
    graph_test_mask,
)
print(f'test loss: {test_loss:.4f}, test acc: {test_acc:.4f}')
accuracy_list.append(['CiteSeer', 'test', float(test_acc)])
test loss: 1.0631, test acc: 0.6960

Plotting the results

# One row per (dataset, split) accuracy recorded during the runs above.
df = pd.DataFrame(accuracy_list, columns=('Dataset', 'Split', 'Accuracy'))
# Pass the frame by keyword so the call does not depend on the position of
# `data` in seaborn's barplot signature.
sns.barplot(data=df, x='Dataset', y='Accuracy', hue='Split', palette="muted")
plt.show()

image