# code for loading notebook's format
import os
# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))
from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
os.chdir(path)
%load_ext watermark
%load_ext autoreload
%autoreload 2
import os
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.nn import functional as F
from datasets import Dataset
from torch.utils.data import DataLoader
from sklearn.preprocessing import LabelEncoder
from transformers import (
PretrainedConfig,
PreTrainedModel,
TrainingArguments,
Trainer
)
import torchmetrics
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
%watermark -a 'Ethen' -d -t -v -u -p torch,sklearn,numpy,pandas,torchmetrics,datasets,transformers
Suppose we have a query denoted as $q$, and its corresponding set of $n$ documents denoted as $D = \{d_1, d_2, \ldots, d_n\}$. Our objective is to learn a function $f$ such that $f(q, D)$ produces an ordered collection of documents, $D^*$, in descending order of relevance, where the exact definition of relevance can vary between applications.
In general, there are three main types of loss functions for training this function: pointwise, pairwise, and listwise. In this article, we'll give a 101 introduction to each of these variants, list their pros and cons, implement these loss functions ourselves, and train a tabular deep learning model using the huggingface Trainer.
Pointwise
In the pointwise approach, the aforementioned ranking task is formulated as a classic regression or classification task. The function $f(q, D)$ is simplified to $f(q, d_i)$, treating the relevance assessment of each query-document pair independently. Suppose we have two queries that yield 2 and 3 corresponding documents respectively:
\begin{align} q_1 & \rightarrow d_1, d_2 \nonumber \\ q_2 & \rightarrow d_3, d_4, d_5 \end{align}

The training examples $x_i$ are created by pairing each query with its associated documents.
\begin{align} x_1&: q_1, d_1 \nonumber \\ x_2&: q_1, d_2 \nonumber \\ x_3&: q_2, d_3 \nonumber \\ x_4&: q_2, d_4 \nonumber \\ x_5&: q_2, d_5 \end{align}

Pros:

- Simplicity: the task reduces to a standard regression/classification problem, so existing tooling and infrastructure apply directly.
- Each query-document pair is scored independently, which makes training and serving straightforward to scale.

Cons:

- It ignores the group structure: documents belonging to the same query are never compared against each other during training, even though ranking quality only depends on their relative order.
- The loss is not aligned with ranking metrics such as NDCG, which care about ordering rather than exact predicted values.
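To make this concrete, here's a minimal pointwise loss sketch with made-up scores and binary labels (not values from any real dataset): each query-document pair is scored independently and fed through a standard binary cross entropy loss.

# toy scores f(q, d_i) for x_1 ... x_5 and their binary relevance labels
scores = torch.FloatTensor([0.5, 1.2, -0.3, 0.8, 0.1])
labels = torch.FloatTensor([0.0, 1.0, 1.0, 1.0, 0.0])
# each pair contributes independently, with no notion of query grouping
pointwise_loss = F.binary_cross_entropy_with_logits(scores, labels)
pointwise_loss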
Pairwise
In the pairwise approach, the goal remains identical to pointwise, in that we're still learning a pointwise scoring function $f(q, d_i)$, but training instances are constructed using pairs of documents from the same query:
\begin{align} x_1&: q_1, (d_1, d_2) \nonumber \\ x_2&: q_2, (d_3, d_4) \nonumber \\ x_3&: q_2, (d_3, d_5) \nonumber \\ x_4&: q_2, (d_4, d_5) \end{align}

This approach introduces a new set of binary pairwise labels, derived by comparing the individual relevance scores within each pair. For instance, consider the first query $q_1$: if $y_1 = 0$ (totally irrelevant) for $d_1$ and $y_2 = 3$ (highly relevant) for $d_2$, a new label $y_1 < y_2$ is assigned to the document pair $(d_1, d_2)$. This transforms the task into a binary classification problem.
To learn the pointwise function $f(q, d_i)$ in a pairwise manner, RankNet [1] proposed modeling the score difference probabilistically using a logistic function:
\begin{align} Pr(i \succ j) = \frac{1}{1 + e^{-(s_i - s_j)}} \end{align}

Where if document $i$ is deemed a better match than document $j$ ($i \succ j$), the probability of the scoring function assigning a higher score $f(q, d_i) = s_i$ than $f(q, d_j) = s_j$ should be close to 1. In other words, the model learns to score document pairs given the query information, effectively learning to rank.
Pros:

- Closer to the nature of ranking than pointwise, as it directly models the relative order between documents of the same query.
- Only requires preference judgements (which of two documents is better), which are often easier to obtain than absolute relevance labels.

Cons:

- The number of pairs grows quadratically with the number of documents per query, making training more expensive.
- It still optimizes pairwise preferences rather than the quality of the entire ranked list.
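As a quick numeric illustration of the RankNet formulation above (a toy sketch with made-up scores, not the full implementation we'll build later): the predicted probability is a sigmoid of the score difference, and since the target for a positive pair is 1, the binary cross entropy reduces to a negative log-sigmoid.

# toy scores where document i is labeled more relevant than document j
s_i = torch.tensor(2.0)
s_j = torch.tensor(0.5)
# Pr(i > j) = sigmoid(s_i - s_j)
prob = torch.sigmoid(s_i - s_j)
# binary cross entropy with target 1 reduces to -log sigmoid(s_i - s_j)
ranknet_loss = -F.logsigmoid(s_i - s_j)
prob, ranknet_loss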
Listwise
The listwise approach addresses the ranking problem in its natural form: it takes in an entire list of instances during training, so the group structure is maintained.
\begin{align} x_1&: q_1, (d_1, d_2) \nonumber \\ x_2&: q_2, (d_3, d_4, d_5) \end{align}

One of the earliest proposed approaches is ListNet [2], where the loss is calculated between a predicted probability distribution and a target probability distribution.
\begin{align} P_{\boldsymbol{y}}\left(x_i\right)=\frac{y_i}{\sum_{j=1}^n y_j} \\ P_f\left(x_i\right)=\frac{e^{f(\boldsymbol{x_i})}}{\sum_{j=1}^n e^{f(\boldsymbol{x_j})}} \end{align}

Where:

- $P_{\boldsymbol{y}}(x_i)$ is the target distribution, obtained by normalizing the relevance labels $y_i$ within the list.
- $P_f(x_i)$ is the predicted distribution, obtained by applying a softmax over the model's scores $f(x_i)$.

The loss for each list is then the cross entropy between these two distributions.
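A minimal sketch of this calculation for a single query, with made-up graded labels and scores:

# toy graded relevance labels and model scores for one query's documents
y = torch.FloatTensor([0.0, 3.0, 1.0])
scores = torch.FloatTensor([0.5, 2.0, 1.0])
# target distribution: normalized labels; predicted distribution: softmax over scores
target_dist = y / y.sum()
pred_log_dist = F.log_softmax(scores, dim=0)
# cross entropy between the two distributions
listnet_loss = -(target_dist * pred_log_dist).sum()
listnet_loss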
Pros:

- Operates on the entire list, making it the formulation most aligned with list-based ranking metrics such as NDCG.

Cons:

- Computationally more expensive, and requires the data pipeline to preserve the query/group structure within each mini-batch.
Note that:
#!wget https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.rar
#!unrar x MQ2008.rar
# showcase some sample raw data
input_path = 'MQ2008/Fold1/train.txt'
with open(input_path) as f:
for _ in range(2):
line = f.readline()
print(line)
Each row represents a query-document pair in the dataset, with columns structured as follows:

- The first column is the graded relevance label (0 to 2, larger meaning more relevant).
- The second column is the query id, in qid:<id> format.
- The following 46 columns are features, each in <feature_id>:<value> format.
- An optional trailing comment starting with # contains metadata such as the document id.
def parse_raw_data(input_path):
labels = []
query_ids = []
features = []
with open(input_path) as f:
for line in f:
# filter out comment about the record
if "#" in line:
line = line[:line.index("#")]
splitted_line = line.strip().split(" ")
label = int(splitted_line[0])
labels.append(label)
query_id = splitted_line[1]
query_ids.append(query_id)
feature = [float(feature_str.split(':')[1]) for feature_str in splitted_line[2:]]
features.append(feature)
df = pd.DataFrame(features)
df["context"] = query_ids
df["label"] = labels
return df
# concatenate data under different cross validation folds together
input_paths = [
'MQ2008/Fold1/train.txt',
'MQ2008/Fold2/train.txt',
'MQ2008/Fold3/train.txt',
'MQ2008/Fold4/train.txt',
'MQ2008/Fold5/train.txt'
]
df_list = []
for input_path in input_paths:
df = parse_raw_data(input_path)
df_list.append(df)
df_train = pd.concat(df_list, ignore_index=True)
df_train["split"] = "train"
print(df_train.shape)
df_train.head()
input_paths = [
'MQ2008/Fold1/vali.txt',
'MQ2008/Fold2/vali.txt',
'MQ2008/Fold3/vali.txt',
'MQ2008/Fold4/vali.txt',
'MQ2008/Fold5/vali.txt'
]
df_list = []
for input_path in input_paths:
df = parse_raw_data(input_path)
df_list.append(df)
df_validation = pd.concat(df_list, ignore_index=True)
df_validation["split"] = "validation"
print(df_validation.shape)
df_validation.head()
# convert context column to numerical indices, PyTorch doesn't take in string field
label_encoder = LabelEncoder()
df = pd.concat([df_train, df_validation], ignore_index=True)
df["context"] = label_encoder.fit_transform(df["context"])
# we can experiment with binary relevance label or the default graded relevance label
# df.loc[df["label"] == 2, "label"] = 1
print(df.shape)
df["label"].value_counts()
df_train = df[df["split"] == "train"].drop(columns=["split"]).reset_index(drop=True)
df_validation = df[df["split"] == "validation"].drop(columns=["split"]).reset_index(drop=True)
dataset_train = Dataset.from_pandas(df_train)
dataset_validation = Dataset.from_pandas(df_validation)
dataset_train
# example feature config
# tabular_features_config = {
# "0": {
# "dtype": "numerical"
# }
# }
tabular_features_config = {str(i): {"dtype": "numerical"} for i in range(46)}
def tabular_collate_fn(batch):
"""
Use in conjunction with Dataloader's collate_fn for tabular data.
Returns
-------
batch : dict
Dictionary with three keys: tabular_inputs, contexts, and labels. Tabular
inputs is a nested field, where each element is a feature_name -> float tensor
mapping. Contexts defines examples that share the same context/query. e.g.
{
'tabular_inputs': {'I1': tensor([0., 0.]), 'C1': tensor([ 888., 1313.])},
'labels': tensor([0, 0]),
'contexts': tensor([0, 0])
}
"""
labels = []
contexts = []
tabular_inputs = {}
for example in batch:
label = example["label"]
labels.append(label)
context = example["context"]
contexts.append(context)
for name in tabular_features_config:
feature = example[name]
if name not in tabular_inputs:
tabular_inputs[name] = [feature]
else:
tabular_inputs[name].append(feature)
for name in tabular_inputs:
tabular_inputs[name] = torch.FloatTensor(tabular_inputs[name])
batch = {
"tabular_inputs": tabular_inputs,
"labels": torch.LongTensor(labels),
"contexts": torch.LongTensor(contexts)
}
return batch
data_loader = DataLoader(dataset_train, batch_size=2, collate_fn=tabular_collate_fn)
batch = next(iter(data_loader))
batch
Learning to rank approaches, regardless of whether they're pairwise or listwise, require data from the same context/group to be in the same mini-batch. Given our input data is in a pointwise format where each row represents a query-document pair, some additional transformations are necessary. We'll use some toy examples to illustrate these points before showing the full-blown implementation.
In pairwise loss, the trick is to broadcast the original 1-d tensor against itself so all pairwise differences can be computed at once.
pred_tensor = torch.FloatTensor([1, 2, 3, 1, 2])
target_tensor = torch.Tensor([0, 1, 1, 1, 0])
context_tensor = torch.Tensor([0, 0, 0, 1, 1])
# 6 total positive pairs before applying the context mask
target_pairwise_diff = target_tensor.unsqueeze(0) - target_tensor.unsqueeze(1)
print("label:\n ", target_pairwise_diff)
# context pairs: pairs from the first 3 examples belong to the same context 0,
# and pairs from examples 4, 5 belong to the same context 1
context_pairwise_diff = context_tensor.unsqueeze(0) == context_tensor.unsqueeze(1)
print("context:\n ", context_pairwise_diff)
In listwise loss, the loss is calculated once for all data within the same context/group. Hence, apart from the predicted scores and targets/labels, we also need to know which examples belong to the same context/group. One common way to do this is to assume the examples are already sorted by context, and keep a group length variable that stores each group's instance count.

In the example below, we have 5 observations belonging to 2 contexts/groups. [3, 2] means the first 3 items belong to the first group, whereas the next 2 items belong to the second group. torch.split then splits the original single tensor into grouped chunks. In a vanilla implementation, we can loop through each group and compute the cross entropy loss.
pred_tensor = torch.FloatTensor([1, 2, 3, 1, 2])
target_tensor = torch.Tensor([0, 1, 1, 1, 0])
group_length_tensor = torch.LongTensor([3, 2])
losses = []
for pred, target in zip(
torch.split(pred_tensor, group_length_tensor.tolist()),
torch.split(target_tensor, group_length_tensor.tolist())
):
# equivalent to cross entropy
# loss = -torch.dot(target, F.log_softmax(pred, dim=0))
loss = F.cross_entropy(pred, target)
losses.append(loss)
loop_listwise_loss = torch.stack(losses)
loop_listwise_loss
A cleaner solution is to pad these grouped chunks and perform the calculation in a batched manner. The padding values matter: we use an extremely small prediction score with 0 as its corresponding label, so padded positions contribute effectively nothing to the softmax or the loss.
pred_group = torch.split(pred_tensor, group_length_tensor.tolist())
target_group = torch.split(target_tensor, group_length_tensor.tolist())
pred_pad = pad_sequence(pred_group, batch_first=True, padding_value=-1e4)
target_pad = pad_sequence(target_group, batch_first=True, padding_value=0.0)
print(pred_pad)
print(target_pad)
loss_fct = nn.CrossEntropyLoss(reduction="none")
batch_listwise_loss = loss_fct(pred_pad, target_pad)
print(batch_listwise_loss)
assert torch.equal(loop_listwise_loss, batch_listwise_loss)
def compute_pairwise_loss(logits, labels, contexts):
    """RankNet style pairwise logistic loss over positive pairs within the same context."""
    logits_positive = logits[:, 0]
    # pairwise score difference and label difference between every pair of examples
    logits_positive_diff = logits_positive.unsqueeze(0) - logits_positive.unsqueeze(1)
    labels_pairwise_diff = labels.unsqueeze(0) - labels.unsqueeze(1)
    # a valid pair must have a positive label difference and come from the same context
    labels_positive_mask = labels_pairwise_diff > 0
    context_pairwise_diff = contexts.unsqueeze(0) == contexts.unsqueeze(1)
    loss_fct = nn.LogSigmoid()
    logits_positive_masked = torch.masked_select(logits_positive_diff, labels_positive_mask * context_pairwise_diff)
    # guard against mini-batches that contain no valid pairs
    if len(logits_positive_masked) == 0:
        pairwise_loss = torch.tensor(0.0, requires_grad=True).to(logits.device)
    else:
        pairwise_loss = -loss_fct(logits_positive_masked).mean()
    return pairwise_loss
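A quick sanity check with the toy tensors from earlier; the model's forward pass produces logits of shape (batch, num_labels), so we unsqueeze the 1-d predictions to mimic that:

toy_logits = pred_tensor.unsqueeze(1)
compute_pairwise_loss(toy_logits, target_tensor, context_tensor)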
def compute_listwise_loss(logits, labels, contexts):
sorted_contexts, sorted_indices = torch.sort(contexts)
sorted_labels = labels[sorted_indices]
logits_positive = logits[:, 0]
sorted_logits_positive = logits_positive[sorted_indices]
    # contexts coming out of our data loader should already be sorted, and using
    # unique_consecutive as opposed to unique avoids an additional sort; note we
    # run it on sorted_contexts so the counts line up with the sorted tensors
    unique_contexts, group_length = torch.unique_consecutive(sorted_contexts, return_counts=True)
logits_positive_group = torch.split(sorted_logits_positive, group_length.tolist())
labels_group = torch.split(sorted_labels, group_length.tolist())
# for logits, pad with an extremely small prediction score, this default value works even
# when using float16 mixed precision training
logits_pad = pad_sequence(logits_positive_group, batch_first=True, padding_value=-1e+4)
labels_pad = pad_sequence(labels_group, batch_first=True, padding_value=0.0)
    # we keep only groups with more than 1 example and at least 1 positive
    # example per group/context
group_mask = (group_length > 1) & (labels_pad.sum(dim=1) > 0)
logits_pad = logits_pad[group_mask]
labels_pad = labels_pad[group_mask]
if len(logits_pad) > 0:
loss_fct = nn.CrossEntropyLoss(reduction="mean")
listwise_loss = loss_fct(logits_pad, labels_pad.float())
else:
listwise_loss = torch.tensor(0.0, requires_grad=True).to(logits.device)
return listwise_loss
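And the same style of sanity check for the listwise loss, re-using the toy tensors; the contexts don't need to be pre-sorted here since the function sorts internally, though mini-batches should still keep each context together:

toy_logits = pred_tensor.unsqueeze(1)
compute_listwise_loss(toy_logits, target_tensor, context_tensor)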
def get_mlp_layers(input_dim: int, mlp_config):
"""
Construct MLP, a.k.a. Feed forward layers based on input config.
Parameters
----------
input_dim :
Input dimension for the first layer.
mlp_config : list of dictionary with mlp spec.
An example is shown below, the only mandatory parameter is hidden size.
```
[
{
"hidden_size": 1024,
"dropout_p": 0.1,
"activation_function": "ReLU",
"activation_function_kwargs": {},
"normalization_function": "LayerNorm"
"normalization_function_kwargs": {"eps": 1e-05}
}
]
```
Returns
-------
nn.Sequential :
Sequential layer converted from input mlp_config. If mlp_config
is None, then this returned value will also be None.
current_dim :
Dimension for the last layer.
"""
if mlp_config is None:
return None, input_dim
layers = []
current_dim = input_dim
for config in mlp_config:
hidden_size = config["hidden_size"]
dropout_p = config.get("dropout_p", 0.0)
activation_function = config.get("activation_function")
activation_function_kwargs = config.get("activation_function_kwargs", {})
normalization_function = config.get("normalization_function")
normalization_function_kwargs = config.get("normalization_function_kwargs", {})
linear = nn.Linear(current_dim, hidden_size)
layers.append(linear)
if normalization_function:
normalization = getattr(nn, normalization_function)(hidden_size, **normalization_function_kwargs)
layers.append(normalization)
if activation_function:
activation = getattr(nn, activation_function)(**activation_function_kwargs)
layers.append(activation)
dropout = nn.Dropout(p=dropout_p)
layers.append(dropout)
current_dim = hidden_size
return nn.Sequential(*layers), current_dim
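For instance, a single hidden layer MLP over our 46 input features (a toy config, separate from the one used for training below):

mlp, last_dim = get_mlp_layers(46, [{"hidden_size": 64, "activation_function": "ReLU"}])
print(last_dim)
mlp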
class TabularModelConfig(PretrainedConfig):
def __init__(
self,
tabular_features_config=None,
mlp_config=None,
num_labels=1,
loss_name="listwise",
**kwargs
):
super().__init__(**kwargs)
self.tabular_features_config = tabular_features_config
self.mlp_config = mlp_config
self.num_labels = num_labels
self.loss_name = loss_name
class TabularModel(PreTrainedModel):
config_class = TabularModelConfig
def __init__(self, config):
super().__init__(config)
self.config = config
self.embeddings, output_dim = self.init_tabular_parameters(config.tabular_features_config)
self.mlp, output_dim = get_mlp_layers(output_dim, config.mlp_config)
self.head = nn.Linear(output_dim, config.num_labels)
if config.loss_name == "pairwise":
self.loss_function = compute_pairwise_loss
else:
self.loss_function = compute_listwise_loss
def forward(self, tabular_inputs, contexts, labels=None):
concatenated_inputs = self.concatenate_tabular_inputs(
tabular_inputs,
self.config.tabular_features_config
)
mlp_outputs = self.mlp(concatenated_inputs)
logits = self.head(mlp_outputs)
loss = None
if labels is not None:
loss = self.loss_function(logits, labels, contexts)
return loss, logits, contexts
def concatenate_tabular_inputs(self, inputs, tabular_features_config):
numerical_inputs = []
categorical_inputs = []
for name, config in tabular_features_config.items():
if config["dtype"] == "categorical":
feature_name = f"{name}_embedding"
share_embedding = config.get("share_embedding")
if share_embedding:
feature_name = f"{share_embedding}_embedding"
embedding = self.embeddings[feature_name]
features = inputs[name].type(torch.long)
embed = embedding(features)
categorical_inputs.append(embed)
elif config["dtype"] == "numerical":
features = inputs[name].type(torch.float32)
if len(features.shape) == 1:
features = features.unsqueeze(dim=1)
numerical_inputs.append(features)
if len(numerical_inputs) > 0:
numerical_inputs = torch.cat(numerical_inputs, dim=-1)
categorical_inputs.append(numerical_inputs)
concatenated_inputs = torch.cat(categorical_inputs, dim=-1)
return concatenated_inputs
def init_tabular_parameters(self, tabular_features_config):
embeddings = {}
output_dim = 0
for name, config in tabular_features_config.items():
if config["dtype"] == "categorical":
feature_name = f"{name}_embedding"
# create new embedding layer for categorical features if share_embedding is None
share_embedding = config.get("share_embedding")
if share_embedding:
                    share_embedding_config = tabular_features_config[share_embedding]
embedding_size = share_embedding_config["embedding_size"]
else:
embedding_size = config["embedding_size"]
embedding = nn.Embedding(config["vocab_size"], embedding_size)
embeddings[feature_name] = embedding
output_dim += embedding_size
elif config["dtype"] == "numerical":
output_dim += 1
return nn.ModuleDict(embeddings), output_dim
mlp_config = [
{
"hidden_size": 1024,
"dropout_p": 0.1,
"activation_function": "ReLU",
"normalization_function": "LayerNorm"
},
{
"hidden_size": 512,
"dropout_p": 0.1,
"activation_function": "ReLU",
"normalization_function": "LayerNorm"
},
{
"hidden_size": 256,
"dropout_p": 0.1,
"activation_function": "ReLU",
"normalization_function": "LayerNorm"
}
]
config = TabularModelConfig(tabular_features_config, mlp_config, loss_name="listwise")
model = TabularModel(config).to(device)
print("# of parameters: ", model.num_parameters())
model
def compute_metrics(eval_preds, round_digits: int = 3):
"""Reports NDCG metrics"""
(y_pred, context), y_true = eval_preds
y_score = y_pred[:, 0]
ndcg_metrics = torchmetrics.retrieval.RetrievalNormalizedDCG(top_k=5)
ndcg = ndcg_metrics(torch.FloatTensor(y_score), torch.FloatTensor(y_true), indexes=torch.LongTensor(context))
return {
'ndcg': ndcg
}
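To see how torchmetrics' RetrievalNormalizedDCG groups things, here's a toy call (made-up values) where rows sharing the same index are treated as one query's candidate documents:

ndcg_metrics = torchmetrics.retrieval.RetrievalNormalizedDCG(top_k=5)
preds = torch.FloatTensor([0.2, 0.9, 0.5, 0.4, 0.7])
target = torch.LongTensor([0, 1, 1, 1, 0])
indexes = torch.LongTensor([0, 0, 0, 1, 1])
# NDCG is computed per query (index) then averaged
ndcg_metrics(preds, target, indexes=indexes)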
When training a learning to rank model, an important detail is to prevent data shuffling in our data loader, so data from the same context can be grouped together in a mini-batch. At the time of writing, huggingface transformers' Trainer enables shuffling on our train dataset by default. We quickly override that behaviour by using get_test_dataloader even for our train dataloader. This addresses the issue with the least amount of code, with the quirk that per_device_eval_batch_size will now also be used as per_device_train_batch_size, which can be a bit confusing.
class TabularRankingTrainer(Trainer):
def get_train_dataloader(self) -> DataLoader:
"""
We should confirm context from this data loader isn't shuffled.
```
dl = trainer.get_train_dataloader()
next(iter(dl))["contexts"]
```
"""
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
return super().get_test_dataloader(self.train_dataset)
os.environ["DISABLE_MLFLOW_INTEGRATION"] = "TRUE"
training_args = TrainingArguments(
output_dir="tabular",
num_train_epochs=50,
learning_rate=0.001,
per_device_train_batch_size=128,
per_device_eval_batch_size=128,
gradient_accumulation_steps=2,
fp16=True,
lr_scheduler_type="constant",
evaluation_strategy="steps",
eval_steps=500,
save_strategy="steps",
save_steps=500,
save_total_limit=2,
do_train=True,
# we are collecting all tabular features into a single entry
# tabular_inputs during collate function, this is to prevent
# huggingface trainer from removing these features while processing
# our dataset
remove_unused_columns=False,
load_best_model_at_end=True
)
trainer = TabularRankingTrainer(
model,
args=training_args,
data_collator=tabular_collate_fn,
train_dataset=dataset_train,
eval_dataset=dataset_validation,
compute_metrics=compute_metrics
)
# on this multi-level graded validation dataset, pairwise/listwise
# loss gives a 0.69 - 0.70 NDCG@5
train_output = trainer.train()
In computational advertising, particularly its click-through rate application, the pointwise loss function still remains the dominant approach because:

- Downstream applications such as ad auctions and bidding consume the predicted click probability directly, so scores need to be well calibrated in absolute terms; pairwise/listwise losses only preserve relative order.
- Pointwise training doesn't require grouping examples from the same query/session into the same mini-batch, keeping the data pipeline and serving infrastructure simpler.
To preserve the benefits of both the pointwise and pairwise/listwise approaches, an intuitive option is to calculate a weighted average of the two loss functions [4] [5]. Given the sparsity of pairwise data, it can also be beneficial to create pseudo pairs to prevent the model from being biased towards the classification loss, e.g. we can form more pairs artificially by grouping impressions from different requests but under the same session and user.
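As a rough sketch of this idea (not taken from the referenced papers), assuming binary relevance labels, one could blend the pointwise binary cross entropy with the pairwise loss implemented earlier, where alpha is a hypothetical weighting hyperparameter to tune:

def compute_hybrid_loss(logits, labels, contexts, alpha=0.5):
    # pointwise component: binary cross entropy, assumes binary labels
    pointwise_loss = F.binary_cross_entropy_with_logits(logits[:, 0], labels.float())
    # pairwise component: the RankNet style loss defined earlier
    pairwise_loss = compute_pairwise_loss(logits, labels, contexts)
    # alpha is a made-up knob trading off calibration vs. ranking quality
    return alpha * pointwise_loss + (1 - alpha) * pairwise_loss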