Mashaan blog

Graph Attention Networks in DGL

Acknowledgment:

References:

@article{wang2019dgl,
 title    ={Deep Graph Library: A Graph-Centric, Highly-Performant Package for Graph Neural Networks},
 author   ={Minjie Wang and Da Zheng and Zihao Ye and Quan Gan and Mufei Li and Xiang Song and Jinjing Zhou and Chao Ma and Lingfan Yu and Yu Gai and Tianjun Xiao and Tong He and George Karypis and Jinyang Li and Zheng Zhang},
 year     ={2019},
 journal  ={arXiv preprint arXiv:1909.01315}
}
@article{velickovic2018graph,
 title    ="{Graph Attention Networks}",
author    ={Veli{\v{c}}kovi{\'{c}}, Petar and Cucurull, Guillem and Casanova, Arantxa and Romero, Adriana and Li{\`{o}}, Pietro and Bengio, Yoshua},
journal   ={International Conference on Learning Representations},
year      ={2018},
}

Import libraries

!pip install dgl
import numpy as np
from IPython.display import Javascript  # Restrict height of output cell.

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.nn as dglnn
from dgl import AddSelfLoop
from dgl.data import CoraGraphDataset

GAT layers from DGL

class GAT(nn.Module):
    """Two-layer graph attention network built from DGL ``GATConv`` layers.

    The first layer uses ``heads[0]`` attention heads with ELU activation and
    its per-head outputs are concatenated. The second layer uses ``heads[1]``
    heads with no activation and its per-head outputs are averaged.

    Args:
        in_size: Size of each input node feature vector.
        hid_size: Per-head output size of the first layer.
        out_size: Per-head output size of the second layer.
        heads: Indexable with the number of attention heads of the first
            (``heads[0]``) and second (``heads[1]``) layer.
        feat_drop: Value passed as ``feat_drop`` to both layers.
        attn_drop: Value passed as ``attn_drop`` to both layers.
    """

    def __init__(
        self, in_size, hid_size, out_size, heads, feat_drop=0.6, attn_drop=0.6
    ):
        super().__init__()
        self.gat_layers = nn.ModuleList()
        # Hidden layer: its heads are concatenated in forward(), so the next
        # layer receives hid_size * heads[0] features per node.
        self.gat_layers.append(
            dglnn.GATConv(
                in_size,
                hid_size,
                heads[0],
                feat_drop=feat_drop,
                attn_drop=attn_drop,
                activation=F.elu,
            )
        )
        # Output layer: no activation, raw per-head scores are returned.
        self.gat_layers.append(
            dglnn.GATConv(
                hid_size * heads[0],
                out_size,
                heads[1],
                feat_drop=feat_drop,
                attn_drop=attn_drop,
                activation=None,
            )
        )

    def forward(self, g, inputs):
        """Run every layer on graph ``g`` and return the final node outputs.

        Args:
            g: Graph passed unchanged to each layer.
            inputs: Input node features passed to the first layer.

        Returns:
            The last layer's output averaged over dimension 1
            (presumably the attention-head dimension; see the GATConv docs).
        """
        # Derive the last index from the layer list instead of hard-coding 1,
        # so the merge rule stays correct if layers are added or removed.
        last = len(self.gat_layers) - 1
        h = inputs
        for i, layer in enumerate(self.gat_layers):
            h = layer(g, h)
            if i == last:
                # Final layer: average over dimension 1.
                h = h.mean(1)
            else:
                # Intermediate layer: flatten everything after dimension 0.
                h = h.flatten(1)

        return h

How does the GAT class process data?

image

image

What happens in the $1^{\text{st}}$ GAT layer?

image

image

Evaluate and train functions

def evaluate(g, features, labels, mask, model):
    """Return the fraction of masked nodes whose predicted class matches its label.

    The model is switched to evaluation mode and run once on the whole graph
    without gradient tracking; it is left in evaluation mode afterwards.

    Args:
        g: Graph passed to the model.
        features: Node features passed to the model.
        labels: Ground-truth class index per node.
        mask: Index/mask selecting the nodes to score.
        model: Callable as ``model(g, features)`` returning per-node logits.

    Returns:
        Accuracy over the selected nodes as a Python float.
    """
    model.eval()
    with torch.no_grad():
        masked_logits = model(g, features)[mask]
        targets = labels[mask]
        # The predicted class is the position of the largest logit in each row.
        predicted = torch.max(masked_logits, dim=1).indices
        num_correct = torch.sum(predicted == targets).item()
    return num_correct * 1.0 / len(targets)
def train(g, features, labels, masks, model, epochs):
    """Train ``model`` with full-graph gradient steps and log progress per epoch.

    Each epoch runs one forward pass, computes the cross-entropy loss on the
    training nodes, takes one Adam step, and prints that loss together with
    the accuracy on the validation nodes.

    Args:
        g: Graph passed to the model.
        features: Node features passed to the model.
        labels: Ground-truth class index per node.
        masks: Indexable whose item 0 selects training nodes and item 1
            selects validation nodes.
        model: Module to optimise in place.
        epochs: Number of training epochs.
    """
    train_mask, val_mask = masks[0], masks[1]
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)
    progress_line = "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} "

    for epoch in range(epochs):
        # evaluate() puts the model in eval mode, so re-enable training mode
        # at the start of every epoch.
        model.train()
        train_logits = model(g, features)[train_mask]
        loss = criterion(train_logits, labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        val_acc = evaluate(g, features, labels, val_mask, model)
        print(progress_line.format(epoch, loss.item(), val_acc))

Import Cora Dataset

# Load the Cora dataset, applying the AddSelfLoop transform to its graph.
transform = AddSelfLoop()
data = CoraGraphDataset(transform=transform)
# The first item of the dataset is the graph used below.
g = data[0]
Downloading /root/.dgl/cora_v2.zip from https://data.dgl.ai/dataset/cora_v2.zip...
/root/.dgl/cora_v2.zip: 100%
 132k/132k [00:00<00:00, 988kB/s]
Extracting file to /root/.dgl/cora_v2_d697a464
Finished data loading and preprocessing.
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done saving data into cached files.

Initialize the model

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
g = g.int().to(device)

# Node features, labels and the train/validation/test masks stored on the graph.
features = g.ndata["feat"]
labels = g.ndata["label"]
masks = tuple(g.ndata[key] for key in ("train_mask", "val_mask", "test_mask"))

# Build the GAT: 8 hidden units per head, 8 heads in the first layer and
# 1 head in the output layer; input/output sizes come from the data.
in_size = features.shape[1]
out_size = data.num_classes
model = GAT(in_size, 8, out_size, heads=[8, 1]).to(device)

Training

# Limit the height of the output cell (Colab-specific JavaScript) so the
# per-epoch log printed by train() stays scrollable.
resize_output_js = '''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''
display(Javascript(resize_output_js))

# Train the model for 200 epochs.
print("Training...")
train(g, features, labels, masks, model, epochs=200)
Training...
Epoch 00000 | Loss 1.9459 | Accuracy 0.2700 
Epoch 00001 | Loss 1.9391 | Accuracy 0.2940 
Epoch 00002 | Loss 1.9244 | Accuracy 0.3300 
Epoch 00003 | Loss 1.9186 | Accuracy 0.3080 
Epoch 00004 | Loss 1.8923 | Accuracy 0.3780 
Epoch 00005 | Loss 1.9063 | Accuracy 0.4540 
Epoch 00006 | Loss 1.8913 | Accuracy 0.5700 
Epoch 00007 | Loss 1.8876 | Accuracy 0.6380 
Epoch 00008 | Loss 1.8661 | Accuracy 0.6860 
Epoch 00009 | Loss 1.8739 | Accuracy 0.7200 
Epoch 00010 | Loss 1.8432 | Accuracy 0.7120
...
...
...
Epoch 00191 | Loss 0.7450 | Accuracy 0.8140 
Epoch 00192 | Loss 0.6682 | Accuracy 0.8100 
Epoch 00193 | Loss 0.6534 | Accuracy 0.8120 
Epoch 00194 | Loss 0.7509 | Accuracy 0.8080 
Epoch 00195 | Loss 0.7447 | Accuracy 0.8080 
Epoch 00196 | Loss 0.7586 | Accuracy 0.8080 
Epoch 00197 | Loss 0.7324 | Accuracy 0.8080 
Epoch 00198 | Loss 0.6985 | Accuracy 0.8060 
Epoch 00199 | Loss 0.7049 | Accuracy 0.8080 

Testing

# Score the trained model on the test nodes (third entry of ``masks``).
print("Testing...")
test_mask = masks[2]
acc = evaluate(g, features, labels, test_mask, model)
print("Test accuracy {:.4f}".format(acc))
Testing...
Test accuracy 0.8220