# code for loading the format for the notebook
import os
# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))
from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
os.chdir(path)
# 1. magic to print version
# 2. magic so that the notebook will reload external python modules
%load_ext watermark
%load_ext autoreload
%autoreload 2
import dgl
import torch
import dgl.data
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchmetrics.functional as MF
from time import perf_counter
from torchmetrics import Accuracy
from ogb.nodeproppred import DglNodePropPredDataset
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
%watermark -a "Ethen" -d -u -v -iv
In this particular notebook, we'll be introducing graph neural networks and training one on a node classification task using DGL, together with PyTorch Lightning's LightningDataModule, LightningModule, as well as Trainer.
A graph $\mathcal{G}(V, E)$ is a data structure containing a set of nodes (a.k.a. vertices) $i \in V$ and a set of edges $e_{ij} \in E$ connecting vertices $i$ and $j$. Each node $i$ has associated node features $x_i \in \mathbb{R}^d$ and a label $y_i$.
A single graph neural network (GNN) layer has three main steps that are performed on every node in the graph:
Message Passing:
A GNN learns a new representation for node $i$ by examining nodes in its neighborhood $N_i$, where $N_i$ is defined as the set of nodes $j$ connected to the source node $i$ by an edge; more formally, $N_i = \{j : e_{ij} \in E\}$. When examining a node, we can use any arbitrary function $F$, e.g. a neural network such as an MLP. Here, let's assume it's an affine transformation:
\begin{align} \begin{aligned} F(x_j) = \mathbf{W} \cdot x_j + b \end{aligned} \end{align}
Here, $\cdot$ represents matrix multiplication.
Aggregation:
Now that we have the messages from neighboring nodes, we have to aggregate them somehow. Popular aggregation functions include sum/mean/max/min. The aggregation function, $G$, can be denoted as:
\begin{align} \begin{aligned} \bar{m}_i = G(\{F(x_j) : j \in \mathcal{N}_i\}) \end{aligned} \end{align}
Update:
The GNN layer now has to update our source node $i$'s features, combining them with the incoming aggregated messages, for example, using addition:
\begin{align} \begin{aligned} h_i = \sigma(K(T(x_i) + \bar{m}_i)) \end{aligned} \end{align}
Here, $T$ denotes a function that's applied to the source node $i$'s feature, $K$ denotes another transformation to project the blended features into another dimension, and $\sigma$ denotes an activation function such as ReLU. Notation-wise, the initial node features are called $x_i$; after a forward pass through a GNN layer, we denote the node features as $h_i$. If we were to have multiple GNN layers, then we denote node features as $h_i^l$, where $l$ is the current GNN layer index.
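To make these three steps concrete, below is a minimal sketch of a single GNN layer on a toy graph using plain PyTorch with sum aggregation. All names and sizes here are made up purely for illustration.
# toy graph: 3 nodes with 4-dim features, where node 2's neighborhood is {0, 1}
x = torch.randn(3, 4)
neighbors = {2: [0, 1]}

# F: message function, T: source node transformation, K: projection on the blended features
F_message = nn.Linear(4, 8)
T = nn.Linear(4, 8)
K = nn.Linear(8, 8)

i = 2
# 1. message passing: apply F to every neighbor j of node i
messages = torch.stack([F_message(x[j]) for j in neighbors[i]])
# 2. aggregation: here G is a simple sum over the incoming messages
m_i = messages.sum(dim=0)
# 3. update: blend the aggregated messages with node i's own transformed feature
h_i = torch.relu(K(T(x[i]) + m_i))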
Note:
This code is largely ported from DGL's PyTorch Lightning node classification example, with additional explanations in between each section.
We will be using the ogbn-products dataset. The following description is copied directly from the dataset's page.
ogbn-products dataset is an undirected and unweighted graph, representing an Amazon product co-purchasing network. Nodes represent products sold in Amazon, and edges between two products indicate that the products are purchased together. Node features are generated by extracting bag-of-words features from the product descriptions followed by a Principal Component Analysis to reduce the dimension to 100.
The task is to predict the category of a product in a multi-class classification setup, where the 47 top-level categories are used for target labels.
dataset = DglNodePropPredDataset("ogbn-products")
graph, labels = dataset[0]
graph
For a DGL graph, we can assign or extract node features via the graph's ndata
attribute. Here, we assign the dataset's labels to our graph.
graph.ndata["label"] = labels.squeeze()
graph.ndata
# extract the train, validation and test split provided via the dataset
split_idx = dataset.get_idx_split()
train_idx, val_idx, test_idx = (
split_idx["train"],
split_idx["valid"],
split_idx["test"],
)
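Purely as an optional sanity check, we can print out the number of nodes in each split:
print(f"train: {len(train_idx)}, valid: {len(val_idx)}, test: {len(test_idx)}")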
Given a graph as well as data splits, we can get our hands dirty and implement our data module.
class DataModule(LightningDataModule):
def __init__(
self, graph, train_idx, val_idx, fanouts, batch_size, n_classes, device
):
super().__init__()
sampler = dgl.dataloading.NeighborSampler(
fanouts, prefetch_node_feats=["feat"], prefetch_labels=["label"]
)
self.graph = graph
self.train_idx = train_idx
self.val_idx = val_idx
self.sampler = sampler
self.batch_size = batch_size
self.in_feats = graph.ndata["feat"].shape[1]
        self.n_classes = n_classes
        # store the device that was passed in instead of relying on a global variable
        self.device = device
def train_dataloader(self):
return dgl.dataloading.DataLoader(
self.graph,
            self.train_idx.to(self.device),
            self.sampler,
            device=self.device,
batch_size=self.batch_size,
shuffle=True,
drop_last=False,
num_workers=0,
use_uva=True
)
def val_dataloader(self):
return dgl.dataloading.DataLoader(
self.graph,
            self.val_idx.to(self.device),
            self.sampler,
            device=self.device,
batch_size=self.batch_size,
            # no need to shuffle while evaluating
            shuffle=False,
drop_last=False,
num_workers=0,
use_uva=True,
)
Similar to general neural networks, we need a DataLoader to sample batches of inputs. A data loader by default returns 3 elements: input_nodes, output_nodes, blocks. input_nodes describe the nodes needed to compute the representation of output_nodes, whereas blocks describe, for each GNN layer, which node representations are to be computed as output, which node representations are needed as input, and how representations from the input nodes propagate to the output nodes.
Each data loader also accepts a sampler; here we are using one called NeighborSampler, which makes every node gather messages from a fixed number of neighbors. We get to define the fanouts parameter for the sampler, which represents the number of neighbors to sample for each GNN layer.
It also supports PyTorch concepts such as prefetching, so model computation and data movement can happen in parallel, as well as a concept called UVA (unified virtual addressing). Directly quoting from its documentation: this is for when our graph is too large to fit onto GPU memory, and we let the GPU perform sampling on a graph that is pinned on CPU memory.
fanouts = [15, 10, 5]
batch_size = 2
data_module = DataModule(graph, train_idx, val_idx, fanouts, batch_size, dataset.num_classes, device)
# sample output from the data loader
input_nodes, output_nodes, blocks = next(iter(data_module.train_dataloader()))
input_nodes, output_nodes, blocks
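To build more intuition around these blocks, we can inspect how many source and destination nodes each sampled block contains (num_src_nodes and num_dst_nodes are part of DGL's block API). The counts shrink as we move towards the final layer's output nodes:
for l, block in enumerate(blocks):
    print(f"block {l}: {block.num_src_nodes()} input nodes -> {block.num_dst_nodes()} output nodes")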
GraphSAGE stands for Graph SAmple and aggreGatE; its forward pass can be described with the following notation:
\begin{align} \begin{aligned} h_{N_i}^{(l+1)} &= \mathrm{aggregate}^{(l+1)} \left(\{h_{j}^{l}, \forall j \in N_i \}\right) \\ h_{i}^{(l+1)} &= \sigma \left(W^{(l+1)} \cdot \mathrm{concat} (h_{i}^{l}, h_{N_i}^{(l+1)}) \right) \\ h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{(l+1)}) \end{aligned} \end{align}
Hopefully, each of these steps won't look that alien after covering the general pattern of a GNN layer above.
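Before using DGL's built-in implementation, it might help to see how a single GraphSAGE update for one node maps onto plain tensor operations. The following is only a rough sketch with made-up tensor names and sizes, using mean aggregation and l2 normalization for the norm step:
# current representation of node i and of its sampled neighbors
h_i = torch.randn(100)
h_neighbors = torch.randn(5, 100)

# W^{(l+1)}, acting on the concatenation of the two representations
W = nn.Linear(2 * 100, 256)

h_neigh = h_neighbors.mean(dim=0)                   # aggregate over the neighborhood
h_next = torch.relu(W(torch.cat([h_i, h_neigh])))   # concat, project, activate
h_next = F.normalize(h_next, dim=0)                 # norm step, here l2 normalization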
The way this works in DGL is: if our features are stored in a graph object's ndata, then from a sampled block object we can access source nodes' features via srcdata and destination nodes' features via dstdata. In the next few code chunks, we first perform a small demo where we access source nodes' features, feed them through a GNN layer, and check whether the output's shape matches the output nodes' label size. After that, we'll proceed with implementing our main model/module.
blocks[0].srcdata["feat"].shape
# out_feats is configurable, analogous to hidden layer's dimension size
sage_conv = dgl.nn.SAGEConv(
in_feats=data_module.in_feats,
out_feats=256,
aggregator_type="mean"
).to(device)
output = sage_conv(blocks[0], blocks[0].srcdata["feat"])
output.shape
blocks[0].dstdata["label"].shape
class SAGE(LightningModule):
"""Multi-layer GraphSAGE lightning module for node classification task."""
def __init__(self, in_feats: int, n_layers: int, n_hidden: int, n_classes: int, aggregator_type: str):
super().__init__()
self.save_hyperparameters()
self.layers = nn.ModuleList()
self.layers.append(dgl.nn.SAGEConv(in_feats, n_hidden, aggregator_type))
for i in range(1, n_layers - 1):
self.layers.append(dgl.nn.SAGEConv(n_hidden, n_hidden, aggregator_type))
self.layers.append(dgl.nn.SAGEConv(n_hidden, n_classes, aggregator_type))
self.dropout = nn.Dropout(0.5)
self.n_hidden = n_hidden
self.n_classes = n_classes
self.train_acc = Accuracy()
self.val_acc = Accuracy()
def forward(self, blocks, x):
h = x
for l, (layer, block) in enumerate(zip(self.layers, blocks)):
h = layer(block, h)
if l != len(self.layers) - 1:
h = F.relu(h)
h = self.dropout(h)
return h
def training_step(self, batch, batch_idx):
input_nodes, output_nodes, blocks = batch
x = blocks[0].srcdata["feat"]
y = blocks[-1].dstdata["label"]
y_hat = self(blocks, x)
loss = F.cross_entropy(y_hat, y)
self.train_acc(torch.argmax(y_hat, 1), y)
self.log(
"train_acc",
self.train_acc,
prog_bar=True,
on_step=True,
on_epoch=False
)
return loss
def validation_step(self, batch, batch_idx):
input_nodes, output_nodes, blocks = batch
x = blocks[0].srcdata["feat"]
y = blocks[-1].dstdata["label"]
y_hat = self(blocks, x)
self.val_acc(torch.argmax(y_hat, 1), y)
self.log(
"val_acc",
self.val_acc,
prog_bar=True,
on_step=True,
on_epoch=True,
sync_dist=True
)
def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.parameters(),
lr=0.001,
weight_decay=5e-4
)
return optimizer
model = SAGE(
in_feats=data_module.in_feats,
n_layers=len(fanouts),
n_hidden=256,
n_classes=data_module.n_classes,
aggregator_type="mean"
)
model
The next code chunk initializes the data and model modules, then kicks off model training through the Trainer class.
n_hidden = 256
fanouts = [15, 10, 5]
aggregator_type = "mean"
batch_size = 1024
n_layers = len(fanouts)
data_module = DataModule(graph, train_idx, val_idx, fanouts, batch_size, dataset.num_classes, device)
model = SAGE(
in_feats=data_module.in_feats,
n_layers=n_layers,
n_hidden=n_hidden,
n_classes=data_module.n_classes,
aggregator_type=aggregator_type
)
# mode="max" since we wish to keep the model with the highest validation accuracy,
# ModelCheckpoint's default mode of "min" is meant for monitoring losses
checkpoint_callback = ModelCheckpoint(monitor="val_acc", save_top_k=1, mode="max")
trainer = Trainer(
accelerator='gpu',
devices=[0],
max_epochs=10,
    # note, we purposefully disabled the progress bar to prevent flooding our notebook's console,
    # in normal settings, we can/should definitely turn it on
enable_progress_bar=False,
log_every_n_steps=100,
callbacks=[checkpoint_callback]
)
t1_start = perf_counter()
trainer.fit(model, datamodule=data_module)
t1_stop = perf_counter()
print("Elapsed time:", t1_stop - t1_start)
The prediction/evaluation step is also a bit interesting. As explained clearly by the DGL tutorial, Exact Offline Inference on Large Graphs: while training our GNN, we oftentimes perform neighborhood sampling to reduce memory usage, but while performing inference, it's better to truly aggregate over all neighbors.
The result of this is that our inference implementation will be slightly different compared to training. During training, we have an outer loop that iterates over mini-batches of nodes (this comes from our DataLoader) and an inner loop that iterates over all our GNN's layers. During inference, we instead have an outer loop that iterates over the GNN layers and an inner loop that iterates over our mini-batches of nodes.
def predict(graph, model, batch_size, device):
    # switch to eval mode so dropout doesn't fire during inference
    model.eval()
    graph.ndata["h"] = graph.ndata["feat"]
    # full neighbor sampler over a single layer, since we iterate over the layers ourselves
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
data_loader = dgl.dataloading.DataLoader(
graph,
torch.arange(graph.number_of_nodes()).to(device),
sampler,
batch_size=batch_size,
shuffle=False,
drop_last=False,
device=device,
num_workers=0,
use_uva=True
)
for l, layer in enumerate(model.layers):
y = torch.zeros(
graph.num_nodes(),
model.n_hidden if l != len(model.layers) - 1 else model.n_classes,
device='cpu'
)
for input_nodes, output_nodes, blocks in data_loader:
block = blocks[0]
x = block.srcdata['h']
h = layer(block, x)
if l != len(model.layers) - 1:
h = F.relu(h)
h = model.dropout(h)
y[output_nodes] = h.to('cpu')
graph.ndata["h"] = y
del graph.ndata['h']
return y
predict_batch_size = 4096
with torch.no_grad():
pred = predict(graph, model.to(device), predict_batch_size, device)
pred = pred[test_idx]
label = graph.ndata["label"][test_idx]
accuracy = MF.accuracy(pred, label)
accuracy = round(accuracy.item(), 3)
print("Test accuracy:", accuracy)
Hopefully, this served as a quick introduction to the GNN node classification task. Feel free to check the leaderboard for potential improvements to this baseline approach.