Mashaan blog

Graph Convolutional Network (GCN) in JAX

Acknowledgment:

I borrowed some code from the “Introduction to Graph Neural Nets with JAX/jraph” tutorial and from the PyTorch Geometric tutorials.

References:

@inproceedings{Fey/Lenssen/2019,
  title={Fast Graph Representation Learning with {PyTorch Geometric}},
  author={Fey, Matthias and Lenssen, Jan E.},
  booktitle={ICLR Workshop on Representation Learning on Graphs and Manifolds},
  year={2019},
}
@software{jraph2020github,
  author = {Jonathan Godwin* and Thomas Keck* and Peter Battaglia and Victor Bapst and Thomas Kipf and Yujia Li and Kimberly Stachenfeld and Petar Veli\v{c}kovi\'{c} and Alvaro Sanchez-Gonzalez},
  title = {Jraph: A library for graph neural networks in jax.},
  url = {http://github.com/deepmind/jraph},
  version = {0.0.1.dev},
  year = {2020},
}

DL and GNN libraries

image

PyG vs Jraph

image

Here are the libraries you need to code GCN in pytorch or jax:

image

The original GCN algorithm uses two layers as shown in equation (9) by Kipf et al. (2017) in https://arxiv.org/abs/1609.02907

image

Import libraries

# install pytorch_geometric
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

# install jraph and jax dependencies
!pip install git+https://github.com/deepmind/jraph.git
# Standard libraries
import numpy as np
from scipy import sparse
import seaborn as sns
import pandas as pd

# Plotting libraries
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib import cm
from IPython.display import Javascript  # Restrict height of output cell.

# sklearn
from sklearn.manifold import TSNE

# PyTorch
import torch
import torch.nn.functional as F
from torch.nn import Linear

# PyTorch geometric
from torch_geometric.utils import from_scipy_sparse_matrix
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import (Planetoid, KarateClub)
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.data import Data
from torch_geometric import seed_everything

# jax
import jax
import jax.numpy as jnp
import jax.tree_util as tree
import flax
import flax.linen as nn
from flax.training import train_state
import optax
import pickle

# jraph
import jraph
from jraph._src import models as jraph_models
# Seed shared by the runs below, and the table that collects their results.
random_seed = 42
accuracy_list = []  # rows of [dataset, split, accuracy], appended after each run

# Plot appearance: the tab10 palette on matplotlib's dark background theme.
plot_colors = cm.tab10.colors
plt.style.use('dark_background')

Import Cora Dataset

You can find a description of this dataset in the PyTorch Geometric documentation. With split='full', all nodes except those in the validation and test sets are used for training.

# Load Cora (downloaded on first use), applying NormalizeFeatures to each graph.
dataset = Planetoid(
    root='data/Planetoid',
    name='Cora',
    split='full',
    transform=NormalizeFeatures(),
)
num_features, num_classes = dataset.num_features, dataset.num_classes
data_Cora = dataset[0]  # the dataset's first graph object
data_Cora  # last expression of the cell: shows the Data summary
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
# Repackage the PyG Cora graph as a jraph.GraphsTuple holding a single graph.
# Tensors are converted directly (as is done for the node features) instead of
# being wrapped in a list and squeezed: the round trip is unnecessary, and
# squeeze() would also drop a genuine size-1 axis (e.g. a graph with one edge).
graph = jraph.GraphsTuple(
      n_node=jnp.asarray([data_Cora.x.shape[0]]),
      n_edge=jnp.asarray([data_Cora.edge_index.shape[1]]),
      nodes=jnp.asarray(data_Cora.x),
      # No edge features.
      edges=None,
      globals=None,
      # Row 0 of edge_index becomes the senders, row 1 the receivers.
      senders=jnp.asarray(data_Cora.edge_index[0]),
      receivers=jnp.asarray(data_Cora.edge_index[1]))

# Per-node split masks and labels as JAX arrays.
graph_train_mask = jnp.asarray(data_Cora.train_mask)
graph_val_mask = jnp.asarray(data_Cora.val_mask)
graph_test_mask = jnp.asarray(data_Cora.test_mask)
graph_labels = jnp.asarray(data_Cora.y)
# One column per distinct label value present in the graph.
one_hot_labels = jax.nn.one_hot(graph_labels, len(jnp.unique(graph_labels)))

Let’s check that we have the right number of nodes and edges. We also want to confirm that the arrays now carry JAX data types (e.g. float32) rather than PyTorch ones (e.g. torch.float32).

# Sanity-check the converted graph: sizes first, then the dtype of each array.
for what, count in (('nodes', graph.n_node[0]), ('edges', graph.n_edge[0])):
  print(f'Number of {what}: {count}')
for label, array in (('Feature matrix data type: ', graph.nodes),
                     ('senders list data type:   ', graph.senders),
                     ('receivers list data type: ', graph.receivers),
                     ('Labels matrix data type:  ', graph_labels)):
  print(f'{label}{array.dtype}')
Number of nodes: 2708
Number of edges: 10556
Feature matrix data type: float32
senders list data type:   int32
receivers list data type: int32
Labels matrix data type:  int32

GCN Layers from Jraph

# Functions must be passed to jraph GNNs, but pytype does not recognise
# linen Modules as callables, so here we wrap them in a function.
def make_embed_fn(latent_size):
  """Returns a plain function that applies ``nn.Dense(latent_size)`` to its input.

  The Dense module is constructed each time the returned function is called,
  not when ``make_embed_fn`` itself is called.
  """
  def apply_dense(features):
    dense = nn.Dense(latent_size)
    return dense(features)
  return apply_dense

def _attention_logit_fn(sender_attr: jnp.ndarray,
                        receiver_attr: jnp.ndarray,
                        edges: jnp.ndarray) -> jnp.ndarray:
  """Maps concatenated sender/receiver attributes through ``nn.Dense(1)``.

  Args:
    sender_attr: attributes on the sending side; joined along axis 1.
    receiver_attr: attributes on the receiving side; joined along axis 1.
    edges: accepted for signature compatibility and discarded.

  Returns:
    The output of the single-unit dense layer.
  """
  del edges  # deliberately unused
  paired = jnp.concatenate([sender_attr, receiver_attr], axis=1)
  to_logit = nn.Dense(1)
  return to_logit(paired)
class GCN(nn.Module):
  """Defines a two-layer GCN network using FLAX and jraph.

  The first jraph.GraphConvolution updates nodes with a dense layer of width
  ``gcn1_output_dim`` followed by ReLU; the second updates them with a dense
  layer of width ``output_dim``.

  Args:
    x: GraphsTuple the network processes.

  Returns:
    output graph with updated node values.
  """
  # Width of the hidden node representation produced by the first layer.
  gcn1_output_dim: int
  # Width of the final node representation (used as per-class scores by the
  # loss function in this file).
  output_dim: int

  @nn.compact
  def __call__(self, x):
    # Layer 1: dense + ReLU node update, with self edges added.
    gcn1 = jraph.GraphConvolution(update_node_fn=lambda n: jax.nn.relu(make_embed_fn(self.gcn1_output_dim)(n)),
                          add_self_edges=True)
    # Layer 2: dense node update, no activation.
    # NOTE(review): unlike gcn1, add_self_edges is not passed here, so jraph's
    # default applies - confirm this is the intended second layer.
    gcn2 = jraph.GraphConvolution(update_node_fn=nn.Dense(self.output_dim))
    return gcn2(gcn1(x))
# Hidden width 8; one output unit per distinct label value in the graph.
model = GCN(gcn1_output_dim=8, output_dim=len(jnp.unique(graph_labels)))
model  # last expression of the cell: shows the module's attributes
GCN(
    # attributes
    gcn1_output_dim = 8
    output_dim = 7
)

Optimizer and Loss

We set the optimizer to Adam using the optax library. Then we initialize the model with random parameters.

# Three keys are split off the seed and kept under these names. Note that the
# initialisation below draws from a separate key built from the same seed.
rng, inp_rng, init_rng = jax.random.split(jax.random.PRNGKey(random_seed), 3)

# Random initial parameters, obtained by running the model's init on the graph.
params = model.init(jax.random.PRNGKey(random_seed), graph)

# Adam with learning rate 0.01, bundled with the parameters and the model's
# apply function into one train state.
optimizer = optax.adam(learning_rate=0.01)
model_state = train_state.TrainState.create(
    tx=optimizer,
    params=params,
    apply_fn=model.apply,
)
def compute_loss(state, params, graph, labels, one_hot_labels, mask):
  """Computes the masked mean cross-entropy loss and accuracy.

  Args:
    state: object whose ``apply_fn(params, graph)`` returns a graph carrying
      per-node class scores in ``.nodes``.
    params: parameters handed to ``state.apply_fn``.
    graph: input graph.
    labels: class index of every node.
    one_hot_labels: one-hot encoding of ``labels``.
    mask: per-node boolean mask; only nodes where it is true contribute.

  Returns:
    A ``(loss, accuracy)`` pair, each averaged over the masked nodes.
  """
  logits = state.apply_fn(params, graph).nodes
  # optax.softmax_cross_entropy takes unnormalised logits and applies
  # log-softmax itself, so the scores are passed through without an extra
  # log_softmax call.
  per_node_loss = optax.softmax_cross_entropy(logits, one_hot_labels)
  num_selected = jnp.sum(mask)
  loss_mask = jnp.sum(jnp.where(mask, per_node_loss, 0)) / num_selected

  # argmax over the logits selects the same class as argmax over the
  # log-probabilities, since log-softmax only shifts each row by a constant.
  pred_labels = jnp.argmax(logits, axis=1)
  acc = (pred_labels == labels)
  acc_mask = jnp.sum(jnp.where(mask, acc, 0)) / num_selected
  return loss_mask, acc_mask

Training

@jax.jit  # Compile the whole update step
def train_step(state, graph, graph_labels, one_hot_labels, train_mask):
  """Performs one gradient update on the nodes selected by ``train_mask``.

  Returns:
    ``(state, loss, acc)``: the updated train state, plus the loss and
    accuracy that were evaluated at the parameters before the update.
  """
  def objective(params):
    # Loss is the differentiated output; accuracy rides along as auxiliary data.
    return compute_loss(state, params, graph, graph_labels, one_hot_labels, train_mask)

  (loss, acc), grads = jax.value_and_grad(objective, has_aux=True)(state.params)
  # Let the state's optimizer turn the gradients into a parameter update.
  return state.apply_gradients(grads=grads), loss, acc
def train_model(state, graph, graph_labels, one_hot_labels, train_mask, val_mask, num_epochs):
  """Runs ``num_epochs`` training steps, printing metrics after each one.

  Returns:
    ``(state, acc, val_acc)``: the final train state and the training and
    validation accuracies from the last epoch. With ``num_epochs == 0`` the
    loop body never runs, so the accuracies are never assigned.
  """
  for step in range(num_epochs):
    state, train_loss, train_acc = train_step(state, graph, graph_labels, one_hot_labels, train_mask)
    # Validation metrics are computed with the parameters produced just above.
    val_loss, val_acc = compute_loss(state, state.params, graph, graph_labels, one_hot_labels, val_mask)
    report = (f'step: {step:03d}, train loss: {train_loss:.4f}, train acc: {train_acc:.4f}, '
              f'val loss: {val_loss:.4f}, val acc: {val_acc:.4f}')
    print(report)
  return state, train_acc, val_acc

image

# Cap the height of this cell's long training log in Colab.
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

# Train on Cora for 200 epochs.
trained_model_state, train_acc, val_acc = train_model(
    model_state, graph, graph_labels, one_hot_labels,
    graph_train_mask, graph_val_mask, num_epochs=200)

# Record the last epoch's accuracies for the summary plot.
accuracy_list.extend([
    ['Cora', 'train', float(train_acc)],
    ['Cora', 'valid', float(val_acc)],
])
step: 000, train loss: 1.9462, train acc: 0.1209, val loss: 1.9386, val acc: 0.3100
step: 001, train loss: 1.9387, train acc: 0.2972, val loss: 1.9307, val acc: 0.3600
step: 002, train loss: 1.9309, train acc: 0.3295, val loss: 1.9215, val acc: 0.3720
step: 003, train loss: 1.9218, train acc: 0.3659, val loss: 1.9113, val acc: 0.3760
step: 004, train loss: 1.9118, train acc: 0.3651, val loss: 1.9013, val acc: 0.3600
step: 005, train loss: 1.9020, train acc: 0.3336, val loss: 1.8914, val acc: 0.3280
step: 006, train loss: 1.8924, train acc: 0.3030, val loss: 1.8811, val acc: 0.3180
step: 007, train loss: 1.8824, train acc: 0.2914, val loss: 1.8707, val acc: 0.3160
step: 008, train loss: 1.8723, train acc: 0.2881, val loss: 1.8601, val acc: 0.3160
step: 009, train loss: 1.8620, train acc: 0.2856, val loss: 1.8494, val acc: 0.3160
step: 010, train loss: 1.8516, train acc: 0.2848, val loss: 1.8388, val acc: 0.3160
...
...
...
step: 191, train loss: 0.2367, train acc: 0.9487, val loss: 0.4670, val acc: 0.8560
step: 192, train loss: 0.2349, train acc: 0.9495, val loss: 0.4665, val acc: 0.8540
step: 193, train loss: 0.2332, train acc: 0.9503, val loss: 0.4660, val acc: 0.8540
step: 194, train loss: 0.2315, train acc: 0.9503, val loss: 0.4655, val acc: 0.8540
step: 195, train loss: 0.2298, train acc: 0.9512, val loss: 0.4650, val acc: 0.8540
step: 196, train loss: 0.2281, train acc: 0.9512, val loss: 0.4646, val acc: 0.8540
step: 197, train loss: 0.2264, train acc: 0.9520, val loss: 0.4641, val acc: 0.8540
step: 198, train loss: 0.2248, train acc: 0.9520, val loss: 0.4637, val acc: 0.8540
step: 199, train loss: 0.2232, train acc: 0.9520, val loss: 0.4633, val acc: 0.8540

Testing

# Evaluate the trained parameters on the nodes selected by the test mask.
test_loss, test_acc = compute_loss(
    trained_model_state, trained_model_state.params,
    graph, graph_labels, one_hot_labels, graph_test_mask)
print(f'test loss: {test_loss:.4f}, test acc: {test_acc:.4f}')
accuracy_list += [['Cora', 'test', float(test_acc)]]
test loss: 0.4664, test acc: 0.8550

Visualizing nodes before and after training

# Model outputs for every node, before and after training, each projected to
# 2-D with t-SNE. The projections are seeded with random_seed so the figure is
# reproducible from run to run.
nodes_untrained = model_state.apply_fn(model_state.params, graph).nodes
z_nodes_untrained = TSNE(n_components=2, random_state=random_seed).fit_transform(nodes_untrained)

nodes_trained = trained_model_state.apply_fn(trained_model_state.params, graph).nodes
z_nodes_trained = TSNE(n_components=2, random_state=random_seed).fit_transform(nodes_trained)

fig, axs = plt.subplots(1, 2, figsize=(12, 6))

# One colour per node, looked up from the palette by the node's label.
node_colors = np.array(plot_colors)[graph_labels]

panels = (('Before training', z_nodes_untrained),
          ('After training', z_nodes_trained))
for ax, (panel_title, embedding) in zip(axs, panels):
  ax.set_title(panel_title)
  ax.scatter(embedding[:,0], embedding[:,1], marker='o', color=node_colors)
  # t-SNE coordinates carry no meaning, so hide ticks and axis labels.
  ax.tick_params(axis='both',which='both',bottom=False,top=False,left=False,right=False,
            labelbottom=False,labeltop=False,labelleft=False,labelright=False)
  ax.set(xlabel=None, ylabel=None)
plt.show()

image

Import CiteSeer Dataset

You can find a description of this dataset in the PyTorch Geometric documentation. With split='full', all nodes except those in the validation and test sets are used for training.

# Load CiteSeer (downloaded on first use), applying NormalizeFeatures to each graph.
dataset = Planetoid(
    root='data/Planetoid',
    name='CiteSeer',
    split='full',
    transform=NormalizeFeatures(),
)
num_features, num_classes = dataset.num_features, dataset.num_classes
data_CiteSeer = dataset[0]  # the dataset's first graph object
data_CiteSeer  # last expression of the cell: shows the Data summary
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.test.index
Processing...
Done!
Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])
# Repackage the PyG CiteSeer graph as a jraph.GraphsTuple holding a single graph.
# Tensors are converted directly (as is done for the node features) instead of
# being wrapped in a list and squeezed: the round trip is unnecessary, and
# squeeze() would also drop a genuine size-1 axis (e.g. a graph with one edge).
graph = jraph.GraphsTuple(
      n_node=jnp.asarray([data_CiteSeer.x.shape[0]]),
      n_edge=jnp.asarray([data_CiteSeer.edge_index.shape[1]]),
      nodes=jnp.asarray(data_CiteSeer.x),
      # No edge features.
      edges=None,
      globals=None,
      # Row 0 of edge_index becomes the senders, row 1 the receivers.
      senders=jnp.asarray(data_CiteSeer.edge_index[0]),
      receivers=jnp.asarray(data_CiteSeer.edge_index[1]))

# Per-node split masks and labels as JAX arrays.
graph_train_mask = jnp.asarray(data_CiteSeer.train_mask)
graph_val_mask = jnp.asarray(data_CiteSeer.val_mask)
graph_test_mask = jnp.asarray(data_CiteSeer.test_mask)
graph_labels = jnp.asarray(data_CiteSeer.y)
# One column per distinct label value present in the graph.
one_hot_labels = jax.nn.one_hot(graph_labels, len(jnp.unique(graph_labels)))
# Same architecture as before, with the output width matching this dataset's labels.
model = GCN(8, len(jnp.unique(graph_labels)))
model
GCN(
    # attributes
    gcn1_output_dim = 8
    output_dim = 6
)
# Three keys are split off the seed and kept under these names. Note that the
# initialisation below draws from a separate key built from the same seed.
rng, inp_rng, init_rng = jax.random.split(jax.random.PRNGKey(random_seed), 3)

# Random initial parameters, obtained by running the model's init on the graph.
params = model.init(jax.random.PRNGKey(random_seed), graph)

# Adam with learning rate 0.01, bundled with the parameters and the model's
# apply function into one train state.
optimizer = optax.adam(learning_rate=0.01)
model_state = train_state.TrainState.create(
    tx=optimizer,
    params=params,
    apply_fn=model.apply,
)

# Cap the height of this cell's long training log in Colab.
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

# Train on CiteSeer for 200 epochs.
trained_model_state, train_acc, val_acc = train_model(
    model_state, graph, graph_labels, one_hot_labels,
    graph_train_mask, graph_val_mask, num_epochs=200)

# Record the last epoch's accuracies for the summary plot.
accuracy_list.extend([
    ['CiteSeer', 'train', float(train_acc)],
    ['CiteSeer', 'valid', float(val_acc)],
])
step: 000, train loss: 1.7918, train acc: 0.1932, val loss: 1.7849, val acc: 0.2080
step: 001, train loss: 1.7857, train acc: 0.1970, val loss: 1.7781, val acc: 0.2080
step: 002, train loss: 1.7796, train acc: 0.1970, val loss: 1.7709, val acc: 0.2080
step: 003, train loss: 1.7730, train acc: 0.1970, val loss: 1.7628, val acc: 0.2100
step: 004, train loss: 1.7654, train acc: 0.1987, val loss: 1.7548, val acc: 0.2160
step: 005, train loss: 1.7579, train acc: 0.2053, val loss: 1.7472, val acc: 0.2200
step: 006, train loss: 1.7507, train acc: 0.2080, val loss: 1.7400, val acc: 0.2180
step: 007, train loss: 1.7439, train acc: 0.2063, val loss: 1.7328, val acc: 0.2180
step: 008, train loss: 1.7370, train acc: 0.2063, val loss: 1.7255, val acc: 0.2180
step: 009, train loss: 1.7300, train acc: 0.2091, val loss: 1.7182, val acc: 0.2240
step: 010, train loss: 1.7230, train acc: 0.2195, val loss: 1.7109, val acc: 0.2420
...
...
...
step: 191, train loss: 0.3845, train acc: 0.8659, val loss: 0.7996, val acc: 0.7560
step: 192, train loss: 0.3827, train acc: 0.8664, val loss: 0.8007, val acc: 0.7560
step: 193, train loss: 0.3809, train acc: 0.8670, val loss: 0.8018, val acc: 0.7540
step: 194, train loss: 0.3791, train acc: 0.8681, val loss: 0.8029, val acc: 0.7540
step: 195, train loss: 0.3773, train acc: 0.8686, val loss: 0.8040, val acc: 0.7540
step: 196, train loss: 0.3756, train acc: 0.8686, val loss: 0.8051, val acc: 0.7560
step: 197, train loss: 0.3738, train acc: 0.8697, val loss: 0.8063, val acc: 0.7560
step: 198, train loss: 0.3721, train acc: 0.8714, val loss: 0.8074, val acc: 0.7580
step: 199, train loss: 0.3704, train acc: 0.8719, val loss: 0.8086, val acc: 0.7580
# Evaluate the trained parameters on the nodes selected by the test mask.
test_loss, test_acc = compute_loss(
    trained_model_state, trained_model_state.params,
    graph, graph_labels, one_hot_labels, graph_test_mask)
print(f'test loss: {test_loss:.4f}, test acc: {test_acc:.4f}')
accuracy_list += [['CiteSeer', 'test', float(test_acc)]]
test loss: 0.7778, test acc: 0.7680

Visualizing nodes before and after training

# Model outputs for every node, before and after training, each projected to
# 2-D with t-SNE. The projections are seeded with random_seed so the figure is
# reproducible from run to run.
nodes_untrained = model_state.apply_fn(model_state.params, graph).nodes
z_nodes_untrained = TSNE(n_components=2, random_state=random_seed).fit_transform(nodes_untrained)

nodes_trained = trained_model_state.apply_fn(trained_model_state.params, graph).nodes
z_nodes_trained = TSNE(n_components=2, random_state=random_seed).fit_transform(nodes_trained)

fig, axs = plt.subplots(1, 2, figsize=(12, 6))

# One colour per node, looked up from the palette by the node's label.
node_colors = np.array(plot_colors)[graph_labels]

panels = (('Before training', z_nodes_untrained),
          ('After training', z_nodes_trained))
for ax, (panel_title, embedding) in zip(axs, panels):
  ax.set_title(panel_title)
  ax.scatter(embedding[:,0], embedding[:,1], marker='o', color=node_colors)
  # t-SNE coordinates carry no meaning, so hide ticks and axis labels.
  ax.tick_params(axis='both',which='both',bottom=False,top=False,left=False,right=False,
            labelbottom=False,labeltop=False,labelleft=False,labelright=False)
  ax.set(xlabel=None, ylabel=None)
plt.show()

image

Plotting the results

# Tabulate the collected [dataset, split, accuracy] rows and draw one bar per
# split, grouped by dataset.
result_columns = ['Dataset', 'Split', 'Accuracy']
df = pd.DataFrame(accuracy_list, columns=result_columns)
sns.barplot(data=df, x='Dataset', y='Accuracy', hue='Split', palette="muted")
plt.show()

image