In [1]:
# code for loading the format for the notebook
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))

from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
Out[1]:
In [2]:
os.chdir(path)

# 1. magic to print version
# 2. magic so that the notebook will reload external python modules
%load_ext watermark
%load_ext autoreload
%autoreload 2

import dgl
import torch
import dgl.data
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchmetrics.functional as MF
from time import perf_counter
from torchmetrics import Accuracy
from ogb.nodeproppred import DglNodePropPredDataset
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

%watermark -a "Ethen" -d -u -v -iv
/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Author: Ethen

Last updated: 2022-10-22

Python implementation: CPython
Python version       : 3.8.10
IPython version      : 8.4.0

torch       : 1.10.0a0+git36449ea
dgl         : 0.9.1
torchmetrics: 0.10.0

Graph Neural Networks Node Classification Quick Introduction

In this particular notebook, we'll be:

  • Giving a quick introduction to Graph Neural Networks (GNNs).
  • Implementing one particular algorithm, GraphSAGE, using DGL (Deep Graph Library). Along the way, we'll introduce how to work with the DGL library on a graph node classification task.
  • Using PyTorch Lightning to organize our building blocks and train our model. This isn't a PyTorch Lightning tutorial; readers are expected to be familiar with its concepts, such as LightningDataModule, LightningModule, and Trainer.

A graph $\mathcal{G}(V, E)$ is a data structure containing a set of nodes (a.k.a. vertices) $i \in V$ and a set of edges $e_{ij} \in E$ connecting nodes $i$ and $j$. Each node $i$ has associated node features $x_i \in \mathbb{R}^d$ and a label $y_i$.

A single graph neural network (GNN) layer performs three main steps on every node in the graph:

  1. Message Passing
  2. Aggregation
  3. Update

Message Passing:

A GNN learns a representation for node $i$ by examining the nodes in its neighborhood $\mathcal{N}_i$, defined as the set of nodes $j$ connected to node $i$ by an edge; more formally, $\mathcal{N}_i = \{j : e_{ij} \in E\}$. To examine a neighbor, we can use any function $F$, for example a neural network such as an MLP. Here, let's assume it is an affine transformation whose weight matrix $\mathbf{W}$ and bias $b$ are shared across all neighbors:

\begin{align} \begin{aligned} F(x_j) = \mathbf{W} \cdot x_j + b \end{aligned} \end{align}

Here, $\cdot$ represents matrix multiplication.
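To make the neighborhood definition and the message function concrete, here is a minimal plain-Python sketch on a hypothetical four-node undirected graph. It assumes a single weight matrix $\mathbf{W}$ and bias $b$ shared by all neighbors; the names `edges`, `neighborhoods`, and `message`, as well as all numbers, are illustrative only.

```python
# Hypothetical undirected graph with 4 nodes; each edge is listed once.
edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
num_nodes = 4

# N_i = {j : e_ij in E}; an undirected edge puts each endpoint in the other's neighborhood.
neighborhoods = {i: set() for i in range(num_nodes)}
for i, j in edges:
    neighborhoods[i].add(j)
    neighborhoods[j].add(i)

# Node features x_i in R^2 (d = 2), made-up values.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]

# Assumption: one weight matrix W (2 x 2) and bias b shared by every neighbor.
W = [[0.5, -1.0], [1.0, 0.5]]
b = [0.1, -0.1]


def message(x_j):
    """Affine message F(x_j) = W . x_j + b."""
    return [sum(w * v for w, v in zip(row, x_j)) + b_k for row, b_k in zip(W, b)]


# Messages node 2 receives from each of its neighbors, keyed by neighbor id.
messages_to_2 = {j: message(x[j]) for j in sorted(neighborhoods[2])}
print(messages_to_2)
```

Because the same `W` and `b` are applied to every neighbor, the number of parameters does not depend on how many neighbors a node has.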

Aggregation:

Now that we have the messages from the neighboring nodes, we need to aggregate them. Popular aggregation functions include sum, mean, max, and min. The aggregation function, $G$, can be denoted as:

\begin{align} \begin{aligned} \bar{m}_i = G(\{F(x_j) : j \in \mathcal{N}_i\}) \end{aligned} \end{align}
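A small sketch of the four aggregators mentioned above, applied element-wise to a hypothetical set of three neighbor messages (plain Python; `aggregate` is an illustrative helper, not a DGL function):

```python
# Hypothetical messages F(x_j) that node i received from three neighbors (d = 2).
messages = [[0.6, 0.9], [-0.9, 0.4], [2.1, 1.4]]


def aggregate(msgs, how):
    """Element-wise aggregation G over a non-empty list of equally sized vectors."""
    columns = list(zip(*msgs))  # one tuple per feature dimension
    if how == "sum":
        return [sum(c) for c in columns]
    if how == "mean":
        return [sum(c) / len(c) for c in columns]
    if how == "max":
        return [max(c) for c in columns]
    if how == "min":
        return [min(c) for c in columns]
    raise ValueError(f"unknown aggregator: {how}")


for how in ("sum", "mean", "max", "min"):
    print(how, aggregate(messages, how))
```

All four are permutation invariant: reordering the neighbors leaves the result unchanged, which matters because a node's neighbors have no natural ordering.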

Update:

Finally, the GNN layer updates the source node $i$'s features by combining them with the incoming aggregated message, for example via addition:

\begin{align} \begin{aligned} h_i = \sigma(K(T(x_i) + \bar{m}_i)) \end{aligned} \end{align}

Here, $T$ denotes a function applied to the source node $i$'s features, $K$ denotes another transformation that projects the blended features into another dimension, and $\sigma$ denotes an activation function such as ReLU. Notation-wise, the initial node features are called $x_i$; after a forward pass through a GNN layer, we denote the node features as $h_i$. With multiple GNN layers, we denote the node features as $h_i^{(l)}$, where $l$ is the current GNN layer index.
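Putting the three steps together, the sketch below runs one GNN layer over a tiny hypothetical graph in plain Python. It assumes mean aggregation for $G$, linear maps for $T$ and $K$ (matrices `T` and `K`), and ReLU for $\sigma$; it also assumes every node has at least one neighbor, since the mean over an empty neighborhood is undefined. The helper names are illustrative.

```python
def matvec(M, v):
    """Dense matrix-vector product M . v (M given as a list of rows)."""
    return [sum(m * a for m, a in zip(row, v)) for row in M]


def add(u, v):
    return [a + c for a, c in zip(u, v)]


def relu(v):
    return [max(0.0, a) for a in v]


def gnn_layer(x, neighborhoods, W, b, T, K):
    """One GNN layer: h_i = relu(K . (T . x_i + mean_j F(x_j))), F(x_j) = W . x_j + b."""
    h = []
    for i, x_i in enumerate(x):
        # 1. Message passing: affine message from every neighbor j.
        msgs = [add(matvec(W, x[j]), b) for j in neighborhoods[i]]
        # 2. Aggregation: element-wise mean over the messages (needs >= 1 neighbor).
        m_i = [sum(col) / len(msgs) for col in zip(*msgs)]
        # 3. Update: transform own features, add the aggregate, project with K, activate.
        h.append(relu(matvec(K, add(matvec(T, x_i), m_i))))
    return h


# Hypothetical 3-node triangle graph with 2-dimensional features.
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
neighborhoods = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
identity = [[1.0, 0.0], [0.0, 1.0]]
W, b, T = identity, [0.0, 0.0], identity
K = [[1.0, -1.0], [0.5, 0.5], [-1.0, 0.0]]  # projects 2 -> 3 dimensions

h = gnn_layer(x, neighborhoods, W, b, T, K)
print(h)
```

Note how `K` changes the output dimension from 2 to 3, which is the role the hidden size plays in a real GNN layer.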

Note:

  • A GNN aims to learn a function that generates node embeddings by sampling and aggregating features from each node's neighborhood. Innovations in GNNs mainly involve changes to these three steps.
  • The number of layers in a GNN is a hyperparameter that can be tuned. The intuition is that the $l^{th}$ GNN layer aggregates features from the $l$-hop neighborhood of node $i$, i.e. in the first layer a node sees its immediate neighbors, and deeper into the network it interacts with its neighbors' neighbors, and so on. Most GNN papers use fewer than 4 layers to avoid over-smoothing, where node embeddings all converge to similar representations after seeing nodes many hops away. This phenomenon becomes more prevalent for small and sparse graphs.
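The $l$-hop intuition can be checked with a breadth-first search: after `num_layers` layers, a node's representation can only depend on nodes within that many hops, itself included (through $T(x_i)$). The sketch below, using a hypothetical six-node path graph and an illustrative `receptive_field` helper, shows this set growing by one hop per layer.

```python
from collections import deque


def receptive_field(neighborhoods, source, num_layers):
    """Nodes whose input features can influence `source` after `num_layers`
    GNN layers, i.e. every node within `num_layers` hops (source included)."""
    depth = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        if depth[node] == num_layers:
            continue  # don't expand beyond the layer budget
        for nbr in neighborhoods[node]:
            if nbr not in depth:
                depth[nbr] = depth[node] + 1
                queue.append(nbr)
    return set(depth)


# Hypothetical path graph 0 - 1 - 2 - 3 - 4 - 5.
path = {i: [j for j in (i - 1, i + 1) if 0 <= j <= 5] for i in range(6)}
for l in range(1, 6):
    print(l, sorted(receptive_field(path, 0, l)))
```

On a graph this small, five layers already let node 0 see every node, so all nodes end up drawing on nearly the same inputs, which gives one intuition for why node embeddings in deep GNNs can converge to similar representations.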

Implementation

This code is largely ported from DGL's PyTorch Lightning node classification example, with additional explanations between sections.

We will be using the ogbn-products dataset. The description below is copied directly from the dataset's description page:

ogbn-products dataset is an undirected and unweighted graph, representing an Amazon product co-purchasing network. Nodes represent products sold in Amazon, and edges between two products indicate that the products are purchased together. Node features are generated by extracting bag-of-words features from the product descriptions followed by a Principal Component Analysis to reduce the dimension to 100.

The task is to predict the category of a product in a multi-class classification setup, where the 47 top-level categories are used for target labels.

In [3]:
dataset = DglNodePropPredDataset("ogbn-products")
graph, labels = dataset[0]
graph
Out[3]:
Graph(num_nodes=2449029, num_edges=123718280,
      ndata_schemes={'feat': Scheme(shape=(100,), dtype=torch.float32)}
      edata_schemes={})

For a DGL graph, we can assign or extract node features via the graph's ndata attribute. Here, we attach the dataset's labels to our graph, squeezing them from shape (num_nodes, 1) into a flat vector of shape (num_nodes,).

In [4]:
graph.ndata["label"] = labels.squeeze()
graph.ndata
Out[4]:
{'feat': tensor([[ 0.0319, -0.1959,  0.0520,  ...,  0.0767, -0.3930, -0.0648],
        [-0.0241,  0.6303,  1.0606,  ..., -1.6875,  3.5867,  0.8182],
        [ 0.3327, -0.5586, -0.2886,  ..., -0.3716,  0.2521,  0.0415],
        ...,
        [ 0.1066,  0.2655, -0.0057,  ...,  1.0867,  0.0759, -1.1737],
        [ 0.2497, -0.2574,  0.4123,  ...,  1.5466,  1.0310, -0.2966],
        [ 0.7175, -0.2393,  0.0443,  ..., -1.0132, -0.4141, -0.0823]]), 'label': tensor([0, 1, 2,  ..., 8, 2, 4])}
In [5]:
# extract the train, validation and test split provided via the dataset
split_idx = dataset.get_idx_split()
train_idx, val_idx, test_idx = (
    split_idx["train"],
    split_idx["valid"],
    split_idx["test"],
)

Data Module

Given a graph as well as data splits, we can get our hands dirty and implement our data module.

In [6]:
class DataModule(LightningDataModule):

    def __init__(
        self, graph, train_idx, val_idx, fanouts, batch_size, n_classes, device
    ):
        super().__init__()

        # sample a fixed number of neighbors per GNN layer, and prefetch the
        # input nodes' features and the output nodes' labels for every mini-batch
        sampler = dgl.dataloading.NeighborSampler(
            fanouts, prefetch_node_feats=["feat"], prefetch_labels=["label"]
        )

        self.graph = graph
        self.train_idx = train_idx
        self.val_idx = val_idx
        self.sampler = sampler
        self.batch_size = batch_size
        self.in_feats = graph.ndata["feat"].shape[1]
        self.n_classes = n_classes
        # use the device passed in, rather than relying on a global variable
        self.device = device

    def _dataloader(self, idx, shuffle):
        return dgl.dataloading.DataLoader(
            self.graph,
            idx.to(self.device),
            self.sampler,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=shuffle,
            drop_last=False,
            num_workers=0,
            # keep the graph in pinned CPU memory and let the GPU sample from it
            use_uva=True,
        )

    def train_dataloader(self):
        return self._dataloader(self.train_idx, shuffle=True)

    def val_dataloader(self):
        # ordering doesn't affect validation metrics, so there's no need to shuffle
        return self._dataloader(self.val_idx, shuffle=False)

As with other neural networks, we need a DataLoader to sample batches of inputs. By default, a DGL data loader returns 3 elements per batch: input_nodes, output_nodes, and blocks.

input_nodes are the nodes whose features are needed to compute the representations of output_nodes, whereas blocks describe, for each GNN layer, which node representations are computed as outputs, which node representations are needed as inputs, and how representations propagate from the input nodes to the output nodes.

Each data loader also accepts a sampler. Here we use NeighborSampler, which makes every node gather messages from a fixed number of sampled neighbors. The sampler's fanouts parameter sets the number of neighbors to sample for each GNN layer.
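The following plain-Python sketch mimics, at a high level, how layer-wise neighbor sampling produces one block per GNN layer. It is not DGL's implementation: `sample_blocks` is a hypothetical helper, sampling is uniform without replacement, and each block is just a `(src_nodes, dst_nodes, edges)` tuple. Sampling starts from the output (seed) nodes with the last fanout and walks back towards the input layer, so `blocks[0]` ends up being the input-side block. The sketch also lists each block's destination nodes first among its source nodes, mirroring how DGL blocks include destination nodes among their source nodes.

```python
import random


def sample_blocks(neighborhoods, seeds, fanouts, rng):
    """Hypothetical sketch of layer-wise neighbor sampling (not DGL's implementation).

    fanouts[l] is the number of neighbors sampled for GNN layer l. Sampling starts
    from the seed (output) nodes with the last fanout and walks back to the input
    layer. Each block is (src_nodes, dst_nodes, edges), edges as (src, dst) pairs.
    """
    blocks = []
    dst_nodes = list(seeds)
    for fanout in reversed(fanouts):
        edges = []
        for dst in dst_nodes:
            nbrs = sorted(neighborhoods[dst])
            # Uniform sampling without replacement; low-degree nodes keep all neighbors.
            picked = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            edges.extend((src, dst) for src in picked)
        # Destination nodes come first among the source nodes (a layer also needs
        # their own previous representation); dict.fromkeys dedupes, keeping order.
        src_nodes = list(dict.fromkeys(dst_nodes + [src for src, _ in edges]))
        blocks.insert(0, (src_nodes, dst_nodes, edges))
        # The previous (input-side) layer must produce these source nodes.
        dst_nodes = src_nodes
    input_nodes = dst_nodes
    return input_nodes, list(seeds), blocks


# Hypothetical random undirected graph with 200 nodes.
rng = random.Random(0)
neighborhoods = {i: set() for i in range(200)}
for _ in range(1000):
    i, j = rng.randrange(200), rng.randrange(200)
    if i != j:
        neighborhoods[i].add(j)
        neighborhoods[j].add(i)

fanouts = [15, 10, 5]
input_nodes, output_nodes, blocks = sample_blocks(neighborhoods, [3, 7], fanouts, rng)
for src, dst, edges in blocks:
    print(f"num_src_nodes={len(src)}, num_dst_nodes={len(dst)}, num_edges={len(edges)}")
```

Each block's destination nodes are exactly the next block's source nodes, which is why the layers can be chained in the model's forward pass.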

The data loader also supports prefetching, so that model computation and data movement can happen in parallel, as well as UVA (unified virtual addressing). Per DGL's documentation, UVA is meant for when the graph is too large to fit in GPU memory: the graph is pinned in CPU memory, and the GPU performs sampling on it directly.

In [7]:
fanouts = [15, 10, 5]
batch_size = 2
data_module = DataModule(graph, train_idx, val_idx, fanouts, batch_size, dataset.num_classes, device)

# sample output from the data loader
input_nodes, output_nodes, blocks = next(iter(data_module.train_dataloader()))
input_nodes, output_nodes, blocks
Out[7]:
(tensor([106546,  74635, 187662,  ..., 232149, 215355,  80118], device='cuda:0'),
 tensor([106546,  74635], device='cuda:0'),
 [Block(num_src_nodes=1428, num_dst_nodes=126, num_edges=1847),
  Block(num_src_nodes=126, num_dst_nodes=12, num_edges=120),
  Block(num_src_nodes=12, num_dst_nodes=2, num_edges=10)])
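The block sizes above can be sanity-checked against the fanouts. Since `blocks[l]` uses `fanouts[l]`, each of its destination nodes has at most `fanouts[l]` sampled edges, and because destination nodes are also included among the source nodes, a block can have at most `num_dst_nodes + num_edges` source nodes. A quick check in plain Python using the numbers copied from Out[7]:

```python
fanouts = [15, 10, 5]
# (num_src_nodes, num_dst_nodes, num_edges) copied from the Out[7] cell above.
observed = [(1428, 126, 1847), (126, 12, 120), (12, 2, 10)]

for layer, (fanout, (n_src, n_dst, n_edges)) in enumerate(zip(fanouts, observed)):
    # Each destination node keeps at most `fanout` sampled in-edges.
    max_edges = n_dst * fanout
    print(f"layer {layer}: {n_edges} edges <= {n_dst} * {fanout} = {max_edges}")
    assert n_edges <= max_edges
    # Sources = destinations plus sampled neighbors, minus duplicates.
    assert n_src <= n_dst + n_edges

# Consecutive blocks chain: layer l's destination nodes are layer l+1's source nodes.
for (_, n_dst, _), (n_src_next, _, _) in zip(observed, observed[1:]):
    assert n_dst == n_src_next
```

The last two blocks hit their edge bounds exactly (12 × 10 = 120 and 2 × 5 = 10), while the first block has 1847 edges against a bound of 126 × 15 = 1890, most likely because some of its 126 destination nodes have fewer than 15 neighbors and keep all of them. The exact numbers will differ between runs since sampling is random.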

GraphSAGE Model

GraphSAGE stands for Graph SAmple and aggreGatE. Its forward pass can be described with the following notation:

\begin{align} \begin{aligned} h_{\mathcal{N}_i}^{(l+1)} &= \mathrm{aggregate}^{(l+1)} \left(\{h_{j}^{(l)}, \forall j \in \mathcal{N}_i \}\right)\\h_{i}^{(l+1)} &= \sigma \left(W^{(l+1)} \cdot \mathrm{concat} (h_{i}^{(l)}, h_{\mathcal{N}_i}^{(l+1)}) \right)\\h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{(l+1)}) \end{aligned} \end{align}

Hopefully, each of these steps won't look too alien after covering the general GNN pattern.

  • For each node, it aggregates feature representations from its immediate neighborhood, which can be uniformly sampled. The original paper explores aggregation functions such as mean, pooling, and LSTM.
  • After aggregating the neighboring feature representations, it concatenates the result with the node's current representation. This concatenation is then fed through a fully connected layer with a nonlinear activation function.
  • The last step normalizes the learned embeddings to unit length.
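The three steps can be sketched in plain Python for a single layer with a mean aggregator. This is an illustration under simple assumptions (no sampling, no bias, a hypothetical `sage_layer` helper and made-up weights), not DGL's `SAGEConv`:

```python
import math


def sage_layer(h, neighborhoods, W):
    """One GraphSAGE-style layer with a mean aggregator.
    W has shape (out_dim, 2 * in_dim) because it acts on concat(h_i, h_N(i))."""
    out = []
    for i, h_i in enumerate(h):
        nbrs = neighborhoods[i]
        # 1. Aggregate: element-wise mean of the neighbors' representations.
        h_nbr = [sum(h[j][k] for j in nbrs) / len(nbrs) for k in range(len(h_i))]
        # 2. Concatenate (list + list) with the node's own representation,
        #    apply W, then ReLU.
        z = h_i + h_nbr
        a = [max(0.0, sum(w * v for w, v in zip(row, z))) for row in W]
        # 3. Normalize to unit L2 length (leave all-zero vectors untouched).
        norm = math.sqrt(sum(v * v for v in a))
        out.append([v / norm for v in a] if norm > 0 else a)
    return out


# Hypothetical 3-node graph with 2-dimensional features.
h0 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
neighborhoods = {0: [1, 2], 1: [0], 2: [0, 1]}
W = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]]  # 3 x 4
h1 = sage_layer(h0, neighborhoods, W)
print(h1)
```

Because of the concatenation, `W` has twice as many input columns as the feature dimension: one half sees the node itself, the other half sees its aggregated neighborhood.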

In DGL, if our features are stored in a graph object's ndata, then from a sampled block we can access the source nodes' features via srcdata and the destination nodes' features via dstdata. In the next few code chunks, we first run a small demo: we access the source nodes' features of the first block, feed them through a single GNN layer, and check that the output's first dimension matches the number of that block's destination nodes (i.e. the size of their labels). After that, we'll implement our main model/module.

In [8]:
blocks[0].srcdata["feat"].shape
Out[8]:
torch.Size([1428, 100])
In [9]:
# out_feats is configurable, analogous to hidden layer's dimension size
sage_conv = dgl.nn.SAGEConv(
    in_feats=data_module.in_feats,
    out_feats=256,
    aggregator_type="mean"
).to(device)
output = sage_conv(blocks[0], blocks[0].srcdata["feat"])
output.shape
Out[9]:
torch.Size([126, 256])
In [10]:
blocks[0].dstdata["label"].shape
Out[10]:
torch.Size([126])
In [11]:
class SAGE(LightningModule):
    """Multi-layer GraphSAGE lightning module for node classification task."""

    def __init__(self, in_feats: int, n_layers: int, n_hidden: int, n_classes: int, aggregator_type: str):
        super().__init__()
        self.save_hyperparameters()

        self.layers = nn.ModuleList()
        self.layers.append(dgl.nn.SAGEConv(in_feats, n_hidden, aggregator_type))
        for i in range(1, n_layers - 1):
            self.layers.append(dgl.nn.SAGEConv(n_hidden, n_hidden, aggregator_type))

        self.layers.append(dgl.nn.SAGEConv(n_hidden, n_classes, aggregator_type))

        self.dropout = nn.Dropout(0.5)
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()

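    def forward(self, blocks, x):
        # blocks[l] holds the sampled edges for layer l; each layer maps the
        # representations of its block's source nodes to its destination nodes
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            # no activation/dropout after the last layer, which outputs raw class logits
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h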
        
    def training_step(self, batch, batch_idx):
        input_nodes, output_nodes, blocks = batch
        x = blocks[0].srcdata["feat"]
        y = blocks[-1].dstdata["label"]
        y_hat = self(blocks, x)
        loss = F.cross_entropy(y_hat, y)
        self.train_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "train_acc",
            self.train_acc,
            prog_bar=True,
            on_step=True,
            on_epoch=False
        )
        return loss

    def validation_step(self, batch, batch_idx):
        input_nodes, output_nodes, blocks = batch
        x = blocks[0].srcdata["feat"]
        y = blocks[-1].dstdata["label"]
        y_hat = self(blocks, x)
        self.val_acc(torch.argmax(y_hat, 1), y)
        self.log(
            "val_acc",
            self.val_acc,
            prog_bar=True,
            on_step=True,
            on_epoch=True,
            sync_dist=True
        )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=0.001,
            weight_decay=5e-4
        )
        return optimizer
In [12]:
model = SAGE(
    in_feats=data_module.in_feats,
    n_layers=len(fanouts),
    n_hidden=256,
    n_classes=data_module.n_classes,
    aggregator_type="mean"
)
model
Out[12]:
SAGE(
  (layers): ModuleList(
    (0): SAGEConv(
      (feat_drop): Dropout(p=0.0, inplace=False)
      (fc_self): Linear(in_features=100, out_features=256, bias=False)
      (fc_neigh): Linear(in_features=100, out_features=256, bias=False)
    )
    (1): SAGEConv(
      (feat_drop): Dropout(p=0.0, inplace=False)
      (fc_self): Linear(in_features=256, out_features=256, bias=False)
      (fc_neigh): Linear(in_features=256, out_features=256, bias=False)
    )
    (2): SAGEConv(
      (feat_drop): Dropout(p=0.0, inplace=False)
      (fc_self): Linear(in_features=256, out_features=47, bias=False)
      (fc_neigh): Linear(in_features=256, out_features=47, bias=False)
    )
  )
  (dropout): Dropout(p=0.5, inplace=False)
  (train_acc): Accuracy()
  (val_acc): Accuracy()
)
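The printout shows that each SAGEConv holds two bias-free linear maps, fc_self and fc_neigh, rather than a single $W$ applied to a concatenation. Assuming the layer sums their outputs, the two forms are equivalent: splitting $W$ column-wise into $[W_{\text{self}} \; W_{\text{neigh}}]$ gives $W \cdot \mathrm{concat}(h_i, h_{\mathcal{N}_i}) = W_{\text{self}} h_i + W_{\text{neigh}} h_{\mathcal{N}_i}$. A quick numeric check in plain Python, with randomly drawn values and arbitrarily chosen sizes:

```python
import random

rng = random.Random(42)
d_in, d_out = 3, 2


def matvec(M, v):
    """Dense matrix-vector product M . v (M given as a list of rows)."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


W = [[rng.uniform(-1, 1) for _ in range(2 * d_in)] for _ in range(d_out)]
h_self = [rng.uniform(-1, 1) for _ in range(d_in)]
h_neigh = [rng.uniform(-1, 1) for _ in range(d_in)]

# Single W applied to the concatenation, as in the GraphSAGE equation.
concat_version = matvec(W, h_self + h_neigh)

# Split W column-wise: the first d_in columns act on h_self, the rest on h_neigh.
W_self = [row[:d_in] for row in W]
W_neigh = [row[d_in:] for row in W]
split_version = [s + n for s, n in zip(matvec(W_self, h_self), matvec(W_neigh, h_neigh))]

print(concat_version)
print(split_version)
```

The two results agree up to floating-point rounding, since only the order of the additions differs.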

Trainer

The next code chunk initializes the data module as well as the model and kicks off model training through the Trainer class.

In [15]:
n_hidden = 256
fanouts = [15, 10, 5]
aggregator_type = "mean"
batch_size = 1024
n_layers = len(fanouts)

data_module = DataModule(graph, train_idx, val_idx, fanouts, batch_size, dataset.num_classes, device)
model = SAGE(
    in_feats=data_module.in_feats,
    n_layers=n_layers,
    n_hidden=n_hidden,
    n_classes=data_module.n_classes,
    aggregator_type=aggregator_type
)

checkpoint_callback = ModelCheckpoint(monitor="val_acc", save_top_k=1)
trainer = Trainer(
    accelerator='gpu',
    devices=[0],
    max_epochs=10,
    # note, we purposefully disabled the progress bar to prevent flooding our notebook's console;
    # in normal settings, we can/should definitely turn it on
    enable_progress_bar=False,
    log_every_n_steps=100,
    callbacks=[checkpoint_callback]
)
t1_start = perf_counter()
trainer.fit(model, datamodule=data_module)
t1_stop = perf_counter()
print("Elapsed time:", t1_stop - t1_start)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | layers    | ModuleList | 206 K 
1 | dropout   | Dropout    | 0     
2 | train_acc | Accuracy   | 0     
3 | val_acc   | Accuracy   | 0     
-----------------------------------------
206 K     Trainable params
0         Non-trainable params
206 K     Total params
0.828     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=10` reached.
Elapsed time: 117.55090914410539
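The 206 K parameter count and the 0.828 MB estimate in the summary above can be reproduced by hand from the layer shapes in the model printout. This assumes each SAGEConv adds one bias vector of length out_feats on top of its two bias-free Linear layers (a plain parameter, so it doesn't show up in the module printout), and that Lightning counts 4 bytes per float32 parameter in decimal megabytes:

```python
# (in_feats, out_feats) of the three SAGEConv layers in the model printout.
layer_dims = [(100, 256), (256, 256), (256, 47)]

total = 0
for in_feats, out_feats in layer_dims:
    weights = 2 * in_feats * out_feats  # fc_self and fc_neigh, both bias-free
    bias = out_feats                    # assumed: one separate bias vector per SAGEConv
    total += weights + bias

size_mb = total * 4 / 1e6  # float32 = 4 bytes; assumed decimal megabytes
print(total, round(size_mb, 3))
```

Without the bias term, the total would be 206,336 parameters and the estimate 0.825 MB, so the reported 0.828 MB is consistent with the bias assumption.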

Evaluation

Prediction/evaluation also works a bit differently. As explained in DGL's tutorial, Exact Offline Inference on Large Graphs: while training a GNN, we often perform neighborhood sampling to reduce memory usage, but during inference it's better to aggregate over all neighbors.

As a result, our inference implementation differs slightly from training. During training, we have an outer loop that iterates over mini-batches of nodes (coming from our DataLoader), and an inner loop that iterates over our GNN's layers. During inference, we instead have an outer loop that iterates over the GNN layers, and an inner loop that iterates over mini-batches of nodes. This way, each layer's output is computed once for every node in the graph and then reused as input to the next layer, instead of recomputing every output node's full multi-hop neighborhood from scratch.
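The sketch below contrasts the two loop orders on a hypothetical six-node cycle graph, with a toy scalar `layer_fn` standing in for a GNN layer and full neighborhoods (no sampling). Both produce the same outputs, but the node-wise recursion recomputes shared lower-layer representations over and over, while the layer-wise version computes each node once per layer and reuses the stored result, which is what the predict function below does by storing each layer's output in graph.ndata["h"]. All names here are illustrative.

```python
def layer_fn(own, neighbor_values):
    """Toy stand-in for a GNN layer: own value plus the mean of the neighbors' values."""
    return own + sum(neighbor_values) / len(neighbor_values)


def nodewise(neighborhoods, x, node, layer, counter):
    """Training-style order: recursively compute one node's layer-`layer` output
    from its full multi-hop neighborhood, with no reuse across calls."""
    if layer == 0:
        return x[node]
    counter[0] += 1  # one layer evaluation
    own = nodewise(neighborhoods, x, node, layer - 1, counter)
    nbrs = [nodewise(neighborhoods, x, j, layer - 1, counter) for j in neighborhoods[node]]
    return layer_fn(own, nbrs)


def layerwise(neighborhoods, x, num_layers, counter):
    """Inference-style order: outer loop over layers, inner loop over all nodes."""
    h = list(x)
    for _ in range(num_layers):
        new_h = []
        for node in range(len(h)):
            counter[0] += 1  # one layer evaluation
            new_h.append(layer_fn(h[node], [h[j] for j in neighborhoods[node]]))
        h = new_h  # store this layer's output for every node and reuse it next layer
    return h


# Hypothetical 6-node cycle graph with scalar features.
num_nodes, num_layers = 6, 3
cycle = {i: [(i - 1) % num_nodes, (i + 1) % num_nodes] for i in range(num_nodes)}
x = [float(i) for i in range(num_nodes)]

node_count, layer_count = [0], [0]
by_node = [nodewise(cycle, x, i, num_layers, node_count) for i in range(num_nodes)]
by_layer = layerwise(cycle, x, num_layers, layer_count)
print(node_count[0], layer_count[0])
```

Here the layer-wise version performs 3 × 6 = 18 layer evaluations, while the node-wise recursion performs 6 × 13 = 78, and the gap grows quickly with depth and node degree.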

In [16]:
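def predict(graph, model, batch_size, device):
    # disable dropout: after trainer.fit, the model may still be in training mode
    model.eval()
    # "h" holds the current layer's input representation for every node
    graph.ndata["h"] = graph.ndata["feat"]
    # 1-layer full neighbor sampler: each block contains all in-edges of its output nodes
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
    data_loader = dgl.dataloading.DataLoader(
        graph,
        torch.arange(graph.num_nodes()).to(device),
        sampler,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        device=device,
        num_workers=0,
        use_uva=True
    )

    # outer loop over layers, inner loop over mini-batches of nodes
    for l, layer in enumerate(model.layers):
        y = torch.zeros(
            graph.num_nodes(),
            model.n_hidden if l != len(model.layers) - 1 else model.n_classes,
            device='cpu'
        )
        for input_nodes, output_nodes, blocks in data_loader:
            block = blocks[0]
            x = block.srcdata['h']
            h = layer(block, x)
            if l != len(model.layers) - 1:
                h = F.relu(h)
                h = model.dropout(h)

            # y lives on the CPU, so move both the indices and the values there
            y[output_nodes.cpu()] = h.cpu()

        # this layer's output becomes the next layer's input
        graph.ndata["h"] = y

    del graph.ndata['h']
    return y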
In [17]:
predict_batch_size = 4096

with torch.no_grad():
    pred = predict(graph, model.to(device), predict_batch_size, device)
    pred = pred[test_idx]
    label = graph.ndata["label"][test_idx]
    accuracy = MF.accuracy(pred, label)
    accuracy = round(accuracy.item(), 3)

print("Test accuracy:", accuracy)
Test accuracy: 0.748

Hopefully, this served as a quick introduction to node classification with GNNs. Feel free to check the ogbn-products leaderboard for potential improvements over this baseline approach.

Reference