deep_learning_learning_to_rank
In [1]:
# code for loading notebook's format
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))

from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
Out[1]:
In [2]:
os.chdir(path)

%load_ext watermark
%load_ext autoreload
%autoreload 2

import os
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.nn import functional as F
from datasets import Dataset
from torch.utils.data import DataLoader
from sklearn.preprocessing import LabelEncoder
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    TrainingArguments,
    Trainer
)
import torchmetrics


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

%watermark -a 'Ethen' -d -t -v -u -p torch,sklearn,numpy,pandas,torchmetrics,datasets,transformers
Author: Ethen

Last updated: 2023-09-06 04:00:55

Python implementation: CPython
Python version       : 3.10.6
IPython version      : 8.13.2

torch       : 2.0.1
sklearn     : 1.3.0
numpy       : 1.23.2
pandas      : 2.0.1
torchmetrics: 1.0.3
datasets    : 2.14.4
transformers: 4.31.0

Learning to Rank 101

Suppose we have a query denoted as $q$ and its corresponding set of $n$ documents denoted as $D = \{d_1, d_2, ..., d_n\}$. Our objective is to learn a function $f$ such that $f(q, D)$ produces an ordered collection of documents, $D^*$, in descending order of relevance, where the exact definition of relevance can vary between applications.

In general, there are three main types of loss functions for training this function: pointwise, pairwise, and listwise. In this article, we'll give a 101 introduction to each of these variants, list their pros and cons, implement these loss functions ourselves, and train a tabular deep learning model using the Hugging Face Trainer.

Pointwise

In the pointwise approach, the aforementioned ranking task is formulated as a classic regression or classification task. The function $f(q, D)$ is simplified to $f(q, d_i)$, treating the relevance assessment of each query-document pair independently. Suppose we have two queries that yield 2 and 3 corresponding documents respectively:

\begin{align} q_1 & \rightarrow d_1, d_2 \nonumber \\ q_2 & \rightarrow d_3, d_4, d_5 \end{align}

The training examples $x_i$ are created by pairing each query with each of its associated documents.

\begin{align} x_1: q_1, d_1 \nonumber \\ x_2: q_1, d_2 \nonumber \\ x_3: q_2, d_3 \nonumber \\ x_4: q_2, d_4 \nonumber \\ x_5: q_2, d_5 \end{align}
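A minimal plain-Python sketch of this pointwise construction, assuming a hypothetical dict that maps each query to its (document, relevance label) pairs; the labels are made up for illustration.

```python
# Hypothetical toy data mirroring the example above:
# query id -> list of (document id, relevance label); labels are made up.
queries = {
    "q1": [("d1", 0), ("d2", 3)],
    "q2": [("d3", 1), ("d4", 0), ("d5", 2)],
}


def build_pointwise_examples(queries):
    """Flatten every query into independent (query, document, label) rows."""
    examples = []
    for query_id, documents in queries.items():
        for doc_id, label in documents:
            examples.append((query_id, doc_id, label))
    return examples


pointwise_examples = build_pointwise_examples(queries)
for example in pointwise_examples:
    print(example)
```

Each row can then be fed to an ordinary regressor or classifier. Nothing ties rows of the same query together, which is the limitation listed under the cons below.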

Pros:

  • Simplicity: Existing machine learning algorithms and loss functions we might be more familiar with can be directly applied in the pointwise setting.

Cons:

  • Sub-Optimal Results: This approach may not fully capitalize on the complete information available across the entire document list for a given query, potentially leading to sub-optimal outcomes.

Pairwise

In the pairwise approach, the goal remains the same as in the pointwise approach: we still learn a pointwise scoring function $f(q, d_i)$, but training instances are constructed using pairs of documents from the same query:

\begin{align} x_1: q_1, (d_1, d_2) \nonumber \\ x_2: q_2, (d_3, d_4) \nonumber \\ x_3: q_2, (d_3, d_5) \nonumber \\ x_4: q_2, (d_4, d_5) \end{align}

This approach introduces a new set of binary pairwise labels, derived by comparing the individual relevance labels within each pair. For instance, for the first query $q_1$, if $y_1 = 0$ (totally irrelevant) for $d_1$ and $y_2 = 3$ (highly relevant) for $d_2$, the document pair $(d_1, d_2)$ is assigned the label $y_1 < y_2$. This turns the task into a binary classification problem.
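The pair construction and its binary labels can be sketched in plain Python as follows. The data and labels are hypothetical, and pairs with equal labels are skipped because they carry no ordering information (an assumption here, though it matches keeping only positive label differences in a pairwise loss).

```python
from itertools import combinations

# Hypothetical toy data: query id -> list of (document id, relevance label)
queries = {
    "q1": [("d1", 0), ("d2", 3)],
    "q2": [("d3", 1), ("d4", 0), ("d5", 2)],
}


def build_pairwise_examples(queries):
    """Pair up documents within each query and derive a binary label.

    The label is 1 when the first document of the pair is more relevant
    than the second and 0 when it is less relevant; ties are skipped.
    """
    examples = []
    for query_id, documents in queries.items():
        # pairs are only formed within a single query, never across queries
        for (doc_i, y_i), (doc_j, y_j) in combinations(documents, 2):
            if y_i == y_j:
                continue
            examples.append((query_id, doc_i, doc_j, int(y_i > y_j)))
    return examples


for example in build_pairwise_examples(queries):
    print(example)
```

With these labels no pair is tied, so the four pairs match $x_1$ to $x_4$ above, e.g. `("q1", "d1", "d2", 0)` encodes $y_1 < y_2$.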

To learn the pointwise function $f(q, d_i)$ in a pairwise manner, RankNet [1] proposed modeling the score difference probabilistically using a logistic function:

\begin{align} Pr(i \succ j) = \frac{1}{1 + e^{-(s_i - s_j)}} \end{align}

Here, if document $i$ is deemed a better match than document $j$ ($i \succ j$), the modeled probability that the scoring function assigns a higher score to $d_i$ ($f(q, d_i) = s_i$) than to $d_j$ ($f(q, d_j) = s_j$) should be close to 1. By fitting these pairwise probabilities to the observed orderings, the model learns to score documents so that they are ranked correctly for a given query.
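A small worked example of the RankNet probability and the corresponding negative log-likelihood loss for a single pair, in plain Python; the scores are arbitrary illustration values.

```python
import math


def ranknet_probability(s_i, s_j):
    """Modeled probability that document i should be ranked above document j."""
    return 1.0 / (1.0 + math.exp(-(s_i - s_j)))


def ranknet_loss(s_i, s_j):
    """Negative log-likelihood of the observed ordering i > j."""
    return -math.log(ranknet_probability(s_i, s_j))


# the more relevant document i already scores higher: small loss
print(round(ranknet_probability(3.0, 1.0), 4))  # 0.8808
print(round(ranknet_loss(3.0, 1.0), 4))  # 0.1269
# the scores are reversed: the loss grows
print(round(ranknet_loss(1.0, 3.0), 4))  # 2.1269
```

Since $-\log \sigma(x)$ equals `-nn.LogSigmoid()(x)`, this is the same per-pair quantity that `compute_pairwise_loss` later averages over valid pairs.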

Pros:

  • Pairwise Ranking Learning: Compared with a pointwise model, a pairwise model learns how to rank in a pairwise context. By focusing on correctly classifying ordered pairs, it can potentially approximate the ultimate task of ranking a whole list of documents more closely.

Cons:

  • Pointwise Scoring: The scoring function remains pointwise, implying relative information among different documents with the same query is not yet fully harnessed.
  • Uneven pairs: The number of document pairs grows roughly quadratically with the number of documents per query. If the number of documents varies widely from query to query and the data is not curated carefully, the trained model may be biased towards queries with more document pairs.
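To make the imbalance concrete, the sketch below counts the orderable pairs (pairs with different labels) for two hypothetical queries with the same 50/50 label mix but different list lengths.

```python
from itertools import combinations


def count_ordered_pairs(labels):
    """Number of document pairs within one query whose labels differ."""
    return sum(1 for y_i, y_j in combinations(labels, 2) if y_i != y_j)


# hypothetical queries: same label mix, 10 vs 100 documents
short_query = [0, 1] * 5
long_query = [0, 1] * 50
print(count_ordered_pairs(short_query))  # 25
print(count_ordered_pairs(long_query))   # 2500
```

Ten times as many documents yields one hundred times as many training pairs, so the longer query dominates the pairwise loss unless the data is balanced.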

Listwise

The listwise approach addresses the ranking problem in its natural form: it takes in a list of instances during training so that the group structure is maintained.

\begin{align} x_1&: q_1, (d_1, d_2) \nonumber \\ x_2&: q_2, (d_3, d_4, d_5) \end{align}

One of the first proposed approaches is ListNet [2], where the loss is calculated between a predicted probability distribution and a target probability distribution:

\begin{align} P_{\boldsymbol{y}}\left(x_i\right)=\frac{y_i}{\sum_{j=1}^n y_j} \\ P_f\left(x_i\right)=\frac{e^{f(\boldsymbol{x_i})}}{\sum_{j=1}^n e^{f(\boldsymbol{x_j})}} \end{align}

Where:

  • $x_i$ denotes features representing a particular query-document pair.
  • $y_i$ represents each document's non-negative relevance label.
  • $P_f\left(x_i\right)$ encodes the probability of $x_i$ appearing at the top of the ranked list, referred to as the top one probability. $P_{\boldsymbol{y}}\left(x_i\right)$ is the corresponding target distribution computed from the labels.

Given these two distributions, the loss can be measured by a standard cross entropy loss:
\begin{align} \ell(\boldsymbol{y}, f(\boldsymbol{x})) = -\sum_{i=1}^n P_{\boldsymbol{y}}\left(x_i\right) \log P_f\left(x_i\right) \end{align}
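A minimal plain-Python sketch of $P_{\boldsymbol{y}}$, $P_f$ and the cross entropy above for a single query. Computing the softmax in log space is an implementation choice for numerical stability, and the labels must contain at least one positive value so that $P_{\boldsymbol{y}}$ is defined.

```python
import math


def listnet_loss(scores, labels):
    """Top one ListNet cross entropy for one query's document list.

    scores: model outputs f(x_i) for each document.
    labels: non-negative relevance labels y_i, at least one of them > 0.
    """
    # target distribution P_y: labels normalized to sum to one
    label_total = sum(labels)
    target = [y / label_total for y in labels]

    # predicted distribution P_f: softmax over scores, kept in log space
    max_score = max(scores)
    log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
    log_pred = [s - log_norm for s in scores]

    # cross entropy between the target and predicted distributions
    return -sum(p * lp for p, lp in zip(target, log_pred))


print(round(listnet_loss([1.0, 2.0, 3.0], [0, 1, 1]), 4))  # 0.9076
```

The normalization in $P_{\boldsymbol{y}}$ matters: passing the raw labels `[0, 1, 1]` as the target, as the loop in the Model section does, gives 1.8152 for the same scores, i.e. the loss scaled by the label sum.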

Pros:

  • Direct Ranking Learning: By formulating the problem in its native form instead of relying on proxies, it is a theoretically sound way to approach a ranking task, i.e., it minimizes errors in ranking the entire document list as opposed to individual document pairs.

Cons:

  • Pointwise Scoring: The scoring function is still pointwise, which could be sub-optimal.

Note that:

  • Although a pointwise classifier can also use a softmax function and cross entropy loss, the listwise loss computes both across all items within the same list rather than for each item in isolation.
  • There is subsequent work [3] that provides theoretical justification for ListNet's softmax cross entropy loss. In particular, it shows that in a binary labeled setup, the loss bounds two popular learning to rank evaluation metrics: Mean Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain (NDCG).
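For reference, both metrics can be sketched in plain Python for one query with binary labels. This assumes linear gains and a $\log_2(\text{position} + 1)$ discount for NDCG; other gain conventions (such as $2^{y} - 1$) exist. MRR is then the mean of the reciprocal rank over all queries.

```python
import math


def reciprocal_rank(ranked_labels):
    """1 / position of the first relevant document, or 0 if none is relevant."""
    for position, label in enumerate(ranked_labels, start=1):
        if label > 0:
            return 1.0 / position
    return 0.0


def ndcg(ranked_labels, k=None):
    """NDCG with linear gains and a log2(position + 1) discount."""
    def dcg(labels):
        return sum(label / math.log2(position + 1)
                   for position, label in enumerate(labels[:k], start=1))

    # the ideal DCG comes from the same labels sorted in descending order
    ideal = dcg(sorted(ranked_labels, reverse=True))
    return dcg(ranked_labels) / ideal if ideal > 0 else 0.0


# binary labels of one query's documents, in the order the model ranked them
ranked_labels = [0, 1, 0, 1]
print(reciprocal_rank(ranked_labels))  # 0.5
print(round(ndcg(ranked_labels), 4))   # 0.6509
```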

Data

We'll be using the LETOR (Learning to Rank) 4.0 dataset [6] [7]. There are also larger datasets, such as MSLR-WEB30k [8] or istella [9], which come in a similar raw format.

In [3]:
#!wget https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.rar
#!unrar x MQ2008.rar
In [4]:
# show case some sample raw data
input_path = 'MQ2008/Fold1/train.txt'

with open(input_path) as f:
    for _ in range(2):
        line = f.readline()
        print(line)
0 qid:10002 1:0.007477 2:0.000000 3:1.000000 4:0.000000 5:0.007470 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.000000 11:0.471076 12:0.000000 13:1.000000 14:0.000000 15:0.477541 16:0.005120 17:0.000000 18:0.571429 19:0.000000 20:0.004806 21:0.768561 22:0.727734 23:0.716277 24:0.582061 25:0.000000 26:0.000000 27:0.000000 28:0.000000 29:0.780495 30:0.962382 31:0.999274 32:0.961524 33:0.000000 34:0.000000 35:0.000000 36:0.000000 37:0.797056 38:0.697327 39:0.721953 40:0.582568 41:0.000000 42:0.000000 43:0.000000 44:0.000000 45:0.000000 46:0.007042 #docid = GX008-86-4444840 inc = 1 prob = 0.086622

0 qid:10002 1:0.603738 2:0.000000 3:1.000000 4:0.000000 5:0.603175 6:0.000000 7:0.000000 8:0.000000 9:0.000000 10:0.000000 11:0.000000 12:0.000000 13:0.122130 14:0.000000 15:0.000000 16:0.998377 17:0.375000 18:1.000000 19:0.000000 20:0.998128 21:0.000000 22:0.000000 23:0.154578 24:0.555676 25:0.000000 26:0.000000 27:0.000000 28:0.000000 29:0.071711 30:0.000000 31:0.000000 32:0.000000 33:0.000000 34:0.000000 35:0.000000 36:0.000000 37:0.000000 38:0.000000 39:0.117399 40:0.560607 41:0.000000 42:0.280000 43:0.000000 44:0.003708 45:0.333333 46:1.000000 #docid = GX037-06-11625428 inc = 0.0031586555555558 prob = 0.0897452

Each row represents a query-document pair in the dataset, with columns structured as follows:

  • First column contains the relevance label for this specific pair. A higher relevance label signifies greater relevance between the query and the document.
  • Second column contains the query ID.
  • Subsequent columns contain various features.
  • The row concludes with a comment about the pair, which includes the document's ID.
In [5]:
def parse_raw_data(input_path):
    labels = []
    query_ids = []
    features = []
    with open(input_path) as f:
        for line in f:
            # filter out comment about the record
            if "#" in line:
                line = line[:line.index("#")]

            splitted_line = line.strip().split(" ")
            label = int(splitted_line[0])
            labels.append(label)

            query_id = splitted_line[1]
            query_ids.append(query_id)

            feature = [float(feature_str.split(':')[1]) for feature_str in splitted_line[2:]]
            features.append(feature)

    df = pd.DataFrame(features)
    df["context"] = query_ids
    df["label"] = labels
    return df
In [6]:
# concatenate data under different cross validation folds together
input_paths = [
    'MQ2008/Fold1/train.txt',
    'MQ2008/Fold2/train.txt',
    'MQ2008/Fold3/train.txt',
    'MQ2008/Fold4/train.txt',
    'MQ2008/Fold5/train.txt'
]

df_list = []
for input_path in input_paths:
    df = parse_raw_data(input_path)
    df_list.append(df)

df_train = pd.concat(df_list, ignore_index=True)
df_train["split"] = "train"
print(df_train.shape)
df_train.head()
(45633, 49)
Out[6]:
0 1 2 3 4 5 6 7 8 9 ... 39 40 41 42 43 44 45 context label split
0 0.007477 0.0 1.0 0.0 0.007470 0.0 0.0 0.0 0.0 0.0 ... 0.582568 0.00 0.00 0.0 0.000000 0.000000 0.007042 qid:10002 0 train
1 0.603738 0.0 1.0 0.0 0.603175 0.0 0.0 0.0 0.0 0.0 ... 0.560607 0.00 0.28 0.0 0.003708 0.333333 1.000000 qid:10002 0 train
2 0.214953 0.0 0.0 0.0 0.213819 0.0 0.0 0.0 0.0 0.0 ... 1.000000 0.00 0.00 0.0 1.000000 1.000000 0.021127 qid:10002 0 train
3 0.000000 0.0 1.0 0.0 0.000000 0.0 0.0 0.0 0.0 0.0 ... 0.000000 0.25 1.00 0.0 0.000000 0.000000 0.000000 qid:10002 0 train
4 1.000000 1.0 0.0 0.0 1.000000 0.0 0.0 0.0 0.0 0.0 ... 0.730347 1.00 0.84 0.0 0.184564 0.666667 0.000000 qid:10002 0 train

5 rows × 49 columns

In [7]:
input_paths = [
    'MQ2008/Fold1/vali.txt',
    'MQ2008/Fold2/vali.txt',
    'MQ2008/Fold3/vali.txt',
    'MQ2008/Fold4/vali.txt',
    'MQ2008/Fold5/vali.txt'
]

df_list = []
for input_path in input_paths:
    df = parse_raw_data(input_path)
    df_list.append(df)

df_validation = pd.concat(df_list, ignore_index=True)
df_validation["split"] = "validation"
print(df_validation.shape)
df_validation.head()
(15211, 49)
Out[7]:
0 1 2 3 4 5 6 7 8 9 ... 39 40 41 42 43 44 45 context label split
0 1.000000 0.0 0.0 0.0 1.000000 0.0 0.0 0.0 0.0 0.0 ... 0.711831 0.50 0.290698 0.0 0.197431 0.50 0.000000 qid:15928 0 validation
1 0.003315 0.0 1.0 0.0 0.005525 0.0 0.0 0.0 0.0 0.0 ... 0.313640 0.50 0.255814 0.0 0.212859 0.25 0.214286 qid:15928 0 validation
2 0.093923 0.0 0.0 0.0 0.093923 0.0 0.0 0.0 0.0 0.0 ... 0.760799 0.25 0.127907 0.0 0.309468 1.00 0.357143 qid:15928 1 validation
3 0.065193 0.0 0.0 0.0 0.065193 0.0 0.0 0.0 0.0 0.0 ... 1.000000 0.25 0.069767 0.0 0.178784 0.50 0.428571 qid:15928 1 validation
4 0.064088 0.0 0.0 0.0 0.064088 0.0 0.0 0.0 0.0 0.0 ... 0.784191 0.50 0.255814 0.0 0.012703 0.25 0.214286 qid:15928 0 validation

5 rows × 49 columns

In [8]:
# convert context column to numerical indices, PyTorch doesn't take in string field
label_encoder = LabelEncoder()

df = pd.concat([df_train, df_validation], ignore_index=True)
df["context"] = label_encoder.fit_transform(df["context"])

# we can experiment with binary relevance label or the default graded relevance label
# df.loc[df["label"] == 2, "label"] = 1

print(df.shape)
df["label"].value_counts()
(60844, 49)
Out[8]:
label
0    49116
1     8004
2     3724
Name: count, dtype: int64
In [9]:
df_train = df[df["split"] == "train"].drop(columns=["split"]).reset_index(drop=True)
df_validation = df[df["split"] == "validation"].drop(columns=["split"]).reset_index(drop=True)
dataset_train = Dataset.from_pandas(df_train)
dataset_validation = Dataset.from_pandas(df_validation)
dataset_train
/usr/local/lib/python3.10/dist-packages/datasets/table.py:761: UserWarning: The DataFrame has column names of mixed type. They will be converted to strings and not roundtrip correctly.
  return cls(pa.Table.from_pandas(*args, **kwargs))
Out[9]:
Dataset({
    features: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', 'context', 'label'],
    num_rows: 45633
})

The following functions/classes for applying deep learning to tabular data closely follow the ones introduced in Deep Learning for Tabular Data - PyTorch.

In [10]:
# example feature config
# tabular_features_config = {
#     "0": {
#         "dtype": "numerical"
#     }
# }
columns = [i for i in range(46)]
tabular_features_config = {str(i): {"dtype": "numerical"} for i in columns}


def tabular_collate_fn(batch):
    """
    Use in conjunction with Dataloader's collate_fn for tabular data.

    Returns
    -------
    batch : dict
        Dictionary with three keys: tabular_inputs, contexts, and labels. Tabular
        inputs is a nested field, where each element is a feature_name -> float tensor
        mapping. Contexts defines examples that share the same context/query. e.g.
        {
            'tabular_inputs': {'I1': tensor([0., 0.]), 'C1': tensor([ 888., 1313.])},
            'labels': tensor([0, 0]),
            'contexts': tensor([0, 0])
        }
    """
    labels = []
    contexts = []
    tabular_inputs = {}
    for example in batch:
        label = example["label"]
        labels.append(label)

        context = example["context"]
        contexts.append(context)

        for name in tabular_features_config:
            feature = example[name]
            if name not in tabular_inputs:
                tabular_inputs[name] = [feature]
            else:
                tabular_inputs[name].append(feature)

    for name in tabular_inputs:
        tabular_inputs[name] = torch.FloatTensor(tabular_inputs[name])

    batch = {
        "tabular_inputs": tabular_inputs,
        "labels": torch.LongTensor(labels),
        "contexts": torch.LongTensor(contexts)
    }
    return batch
In [11]:
data_loader = DataLoader(dataset_train, batch_size=2, collate_fn=tabular_collate_fn)
batch = next(iter(data_loader))
batch
Out[11]:
{'tabular_inputs': {'0': tensor([0.0075, 0.6037]),
  '1': tensor([0., 0.]),
  '2': tensor([1., 1.]),
  '3': tensor([0., 0.]),
  '4': tensor([0.0075, 0.6032]),
  '5': tensor([0., 0.]),
  '6': tensor([0., 0.]),
  '7': tensor([0., 0.]),
  '8': tensor([0., 0.]),
  '9': tensor([0., 0.]),
  '10': tensor([0.4711, 0.0000]),
  '11': tensor([0., 0.]),
  '12': tensor([1.0000, 0.1221]),
  '13': tensor([0., 0.]),
  '14': tensor([0.4775, 0.0000]),
  '15': tensor([0.0051, 0.9984]),
  '16': tensor([0.0000, 0.3750]),
  '17': tensor([0.5714, 1.0000]),
  '18': tensor([0., 0.]),
  '19': tensor([0.0048, 0.9981]),
  '20': tensor([0.7686, 0.0000]),
  '21': tensor([0.7277, 0.0000]),
  '22': tensor([0.7163, 0.1546]),
  '23': tensor([0.5821, 0.5557]),
  '24': tensor([0., 0.]),
  '25': tensor([0., 0.]),
  '26': tensor([0., 0.]),
  '27': tensor([0., 0.]),
  '28': tensor([0.7805, 0.0717]),
  '29': tensor([0.9624, 0.0000]),
  '30': tensor([0.9993, 0.0000]),
  '31': tensor([0.9615, 0.0000]),
  '32': tensor([0., 0.]),
  '33': tensor([0., 0.]),
  '34': tensor([0., 0.]),
  '35': tensor([0., 0.]),
  '36': tensor([0.7971, 0.0000]),
  '37': tensor([0.6973, 0.0000]),
  '38': tensor([0.7220, 0.1174]),
  '39': tensor([0.5826, 0.5606]),
  '40': tensor([0., 0.]),
  '41': tensor([0.0000, 0.2800]),
  '42': tensor([0., 0.]),
  '43': tensor([0.0000, 0.0037]),
  '44': tensor([0.0000, 0.3333]),
  '45': tensor([0.0070, 1.0000])},
 'labels': tensor([0, 0]),
 'contexts': tensor([0, 0])}

Model

Learning to rank approaches, whether pairwise or listwise, require data from the same context/group to be in the same mini-batch. Since our input data is in a pointwise format where each row represents a query-document pair, some additional transformations are necessary. We'll use a few toy examples to illustrate these points before showing the full implementation.

In the pairwise loss, the trick is to expand the original 1D tensors so that all pairwise comparisons can be computed at once.

  • The context's pairwise equality signals which pairs belong to the same context. Pairs belonging to different contexts should be masked out during the loss calculation.
  • The label's pairwise difference is antisymmetric (entry [i, j] is the negative of entry [j, i]), so we only need to consider pairs where the difference is positive and treat them as binary labels.
  • The prediction score's pairwise difference is the input to our loss function.
In [12]:
pred_tensor = torch.FloatTensor([1, 2, 3, 1, 2])
target_tensor = torch.Tensor([0, 1, 1, 1, 0])
context_tensor = torch.Tensor([0, 0, 0, 1, 1])

# 6 positive label pairs in total, before masking out pairs across contexts
target_pairwise_diff = target_tensor.unsqueeze(0) - target_tensor.unsqueeze(1)
print("label:\n ", target_pairwise_diff)

# context pairs: pairs among the first 3 examples share context 0,
# and pairs among examples 4 and 5 share context 1
context_pairwise_diff = context_tensor.unsqueeze(0) == context_tensor.unsqueeze(1)
print("context:\n ", context_pairwise_diff)
label:
  tensor([[ 0.,  1.,  1.,  1.,  0.],
        [-1.,  0.,  0.,  0., -1.],
        [-1.,  0.,  0.,  0., -1.],
        [-1.,  0.,  0.,  0., -1.],
        [ 0.,  1.,  1.,  1.,  0.]])
context:
  tensor([[ True,  True,  True, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True, False, False],
        [False, False, False,  True,  True],
        [False, False, False,  True,  True]])
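Combining the two matrices (a positive label difference AND a shared context) leaves the pairs that actually enter the loss. The same toy values in plain Python, where entry [i, j] of the label matrix is `target[j] - target[i]` as produced by the `unsqueeze(0) - unsqueeze(1)` broadcast:

```python
target = [0, 1, 1, 1, 0]
context = [0, 0, 0, 1, 1]

# keep (i, j) when document j is more relevant than document i
# and both documents share the same context
valid_pairs = [
    (i, j)
    for i in range(len(target))
    for j in range(len(target))
    if target[j] - target[i] > 0 and context[i] == context[j]
]
print(valid_pairs)  # [(0, 1), (0, 2), (4, 3)]
```

Only 3 of the 6 positive label pairs survive; the other 3 compare documents from different queries.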

In the listwise loss, the loss is calculated once over all data within the same context/group. Hence, apart from the predicted scores and targets/labels, we also need to know which examples belong to the same context/group. One common way to do this is to assume the examples are already sorted by context and keep a group length variable that stores each group's instance count.

In the example below, we have 5 observations belonging to 2 contexts/groups. [3, 2] means the first 3 items belong to the first group, whereas the next 2 items belong to the second group. torch.split then splits the original single tensor into these grouped chunks. In a vanilla implementation, we can loop through each group and compute its cross entropy loss.

In [13]:
pred_tensor = torch.FloatTensor([1, 2, 3, 1, 2])
target_tensor = torch.Tensor([0, 1, 1, 1, 0])
group_length_tensor = torch.LongTensor([3, 2])

losses = []
for pred, target in zip(
    torch.split(pred_tensor, group_length_tensor.tolist()),
    torch.split(target_tensor, group_length_tensor.tolist())
):
    # equivalent to cross entropy
    # loss = -torch.dot(target, F.log_softmax(pred, dim=0))
    loss = F.cross_entropy(pred, target)
    losses.append(loss)

loop_listwise_loss = torch.stack(losses)
loop_listwise_loss
Out[13]:
tensor([1.8152, 1.3133])
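When `F.cross_entropy` receives a float target of the same shape as the scores, it treats it as class probabilities and computes $-\sum_i t_i \log \text{softmax}(s)_i$ without normalizing $t$, the dot-product form in the commented-out line. A plain-Python re-derivation of both group losses:

```python
import math


def soft_target_cross_entropy(scores, target):
    """-sum_i target_i * log_softmax(scores)_i; the target is not normalized."""
    max_score = max(scores)
    log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
    return -sum(t * (s - log_norm) for s, t in zip(scores, target))


print(round(soft_target_cross_entropy([1.0, 2.0, 3.0], [0.0, 1.0, 1.0]), 4))  # 1.8152
print(round(soft_target_cross_entropy([1.0, 2.0], [1.0, 0.0]), 4))            # 1.3133
```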

A cleaner solution is to pad these grouped chunks and perform the calculation in a batched manner. The padding values matter: we pad predictions with an extremely small score, so the padded positions receive effectively zero softmax probability, and pad labels with 0, so they contribute nothing to the target.

In [14]:
pred_group = torch.split(pred_tensor, group_length_tensor.tolist())
target_group = torch.split(target_tensor, group_length_tensor.tolist())
pred_pad = pad_sequence(pred_group, batch_first=True, padding_value=-1e4)
target_pad = pad_sequence(target_group, batch_first=True, padding_value=0.0)
print(pred_pad)
print(target_pad)

loss_fct = nn.CrossEntropyLoss(reduction="none")
batch_listwise_loss = loss_fct(pred_pad, target_pad)
print(batch_listwise_loss)

assert torch.equal(loop_listwise_loss, batch_listwise_loss)
tensor([[ 1.0000e+00,  2.0000e+00,  3.0000e+00],
        [ 1.0000e+00,  2.0000e+00, -1.0000e+04]])
tensor([[0., 1., 1.],
        [1., 0., 0.]])
tensor([1.8152, 1.3133])
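To see why the padding value matters, the plain-Python sketch below recomputes the second group's loss unpadded, padded with -1e4, and padded with a naive 0.0 score.

```python
import math


def soft_target_cross_entropy(scores, target):
    """-sum_i target_i * log_softmax(scores)_i; the target is not normalized."""
    max_score = max(scores)
    log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
    return -sum(t * (s - log_norm) for s, t in zip(scores, target))


# second group from the toy example: scores [1, 2] with target [1, 0]
scores, target = [1.0, 2.0], [1.0, 0.0]
unpadded = soft_target_cross_entropy(scores, target)

# padding with -1e4: exp(-1e4 - 2) underflows to 0.0, so the loss is unchanged
padded_small = soft_target_cross_entropy(scores + [-1e4], target + [0.0])

# padding with 0.0: the fake entry takes softmax mass away from real documents
padded_zero = soft_target_cross_entropy(scores + [0.0], target + [0.0])

print(round(unpadded, 4), round(padded_small, 4), round(padded_zero, 4))  # 1.3133 1.3133 1.4076
```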
In [15]:
def compute_pairwise_loss(logits, labels, contexts):
    # score difference for every pair, entry [i, j] = s_j - s_i
    logits_positive = logits[:, 0]
    logits_positive_diff = logits_positive.unsqueeze(0) - logits_positive.unsqueeze(1)

    # keep pairs where document j is more relevant than document i
    # and both documents belong to the same context
    labels_pairwise_diff = labels.unsqueeze(0) - labels.unsqueeze(1)
    labels_positive_mask = labels_pairwise_diff > 0
    context_pairwise_diff = contexts.unsqueeze(0) == contexts.unsqueeze(1)

    # RankNet loss: -log(sigmoid(s_j - s_i)) averaged over the valid pairs
    loss_fct = nn.LogSigmoid()
    logits_positive_masked = torch.masked_select(logits_positive_diff, labels_positive_mask & context_pairwise_diff)
    if len(logits_positive_masked) == 0:
        pairwise_loss = torch.tensor(0.0, requires_grad=True).to(logits.device)
    else:
        pairwise_loss = -loss_fct(logits_positive_masked).mean()

    return pairwise_loss


def compute_listwise_loss(logits, labels, contexts):
    sorted_contexts, sorted_indices = torch.sort(contexts)
    sorted_labels = labels[sorted_indices]
    logits_positive = logits[:, 0]
    sorted_logits_positive = logits_positive[sorted_indices]

    # count group sizes on the sorted contexts; since they are sorted,
    # unique_consecutive suffices and avoids the extra sort torch.unique would do
    unique_contexts, group_length = torch.unique_consecutive(sorted_contexts, return_counts=True)

    logits_positive_group = torch.split(sorted_logits_positive, group_length.tolist())
    labels_group = torch.split(sorted_labels, group_length.tolist())

    # for logits, pad with an extremely small prediction score, this default value works even
    # when using float16 mixed precision training
    logits_pad = pad_sequence(logits_positive_group, batch_first=True, padding_value=-1e+4)
    labels_pad = pad_sequence(labels_group, batch_first=True, padding_value=0.0)

    # keep only groups/contexts with more than 1 example
    # and at least 1 positive example
    group_mask = (group_length > 1) & (labels_pad.sum(dim=1) > 0)
    logits_pad = logits_pad[group_mask]
    labels_pad = labels_pad[group_mask]

    if len(logits_pad) > 0:
        loss_fct = nn.CrossEntropyLoss(reduction="mean")
        listwise_loss = loss_fct(logits_pad, labels_pad.float())
    else:
        listwise_loss = torch.tensor(0.0, requires_grad=True).to(logits.device)            

    return listwise_loss
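Why the group lengths must be counted on sorted context ids: `torch.unique_consecutive` only merges adjacent equal values, so on unsorted ids it reports runs rather than groups, and the lengths no longer line up with the sorted logits and labels. The standard library's `itertools.groupby` has the same run-based behavior, which makes it easy to illustrate with hypothetical ids:

```python
from itertools import groupby


def group_lengths(contexts):
    """Lengths of runs of consecutive equal context ids."""
    return [len(list(run)) for _, run in groupby(contexts)]


contexts = [1, 0, 1, 0, 0]  # hypothetical, deliberately unsorted
print(group_lengths(contexts))          # [1, 1, 1, 2]
print(group_lengths(sorted(contexts)))  # [3, 2]
```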
In [16]:
def get_mlp_layers(input_dim: int, mlp_config):
    """
    Construct MLP, a.k.a. Feed forward layers based on input config.

    Parameters
    ----------
    input_dim : 
        Input dimension for the first layer.

    mlp_config : list of dictionary with mlp spec.
        An example is shown below, the only mandatory parameter is hidden size.
        ```
        [
            {
                "hidden_size": 1024,
                "dropout_p": 0.1,
                "activation_function": "ReLU",
                "activation_function_kwargs": {},
                "normalization_function": "LayerNorm",
                "normalization_function_kwargs": {"eps": 1e-05}
            }
        ]
        ```

    Returns
    -------
    nn.Sequential :
        Sequential layer converted from input mlp_config. If mlp_config
        is None, then this returned value will also be None.

    current_dim :
        Dimension for the last layer.
    """
    if mlp_config is None:
        return None, input_dim

    layers = []
    current_dim = input_dim
    for config in mlp_config:
        hidden_size = config["hidden_size"]
        dropout_p = config.get("dropout_p", 0.0)
        activation_function = config.get("activation_function")
        activation_function_kwargs = config.get("activation_function_kwargs", {})
        normalization_function = config.get("normalization_function")
        normalization_function_kwargs = config.get("normalization_function_kwargs", {})

        linear = nn.Linear(current_dim, hidden_size)
        layers.append(linear)

        if normalization_function:
            normalization = getattr(nn, normalization_function)(hidden_size, **normalization_function_kwargs)
            layers.append(normalization)

        if activation_function:
            activation = getattr(nn, activation_function)(**activation_function_kwargs)
            layers.append(activation)

        dropout = nn.Dropout(p=dropout_p)
        layers.append(dropout)
        current_dim = hidden_size

    return nn.Sequential(*layers), current_dim
In [17]:
class TabularModelConfig(PretrainedConfig):

    def __init__(
        self,
        tabular_features_config=None,
        mlp_config=None,
        num_labels=1,
        loss_name="listwise",
        **kwargs
    ):
        super().__init__(**kwargs)
        self.tabular_features_config = tabular_features_config
        self.mlp_config = mlp_config
        self.num_labels = num_labels
        self.loss_name = loss_name
In [18]:
class TabularModel(PreTrainedModel):

    config_class = TabularModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings, output_dim = self.init_tabular_parameters(config.tabular_features_config)
        self.mlp, output_dim = get_mlp_layers(output_dim, config.mlp_config)
        self.head = nn.Linear(output_dim, config.num_labels)

        if config.loss_name == "pairwise":
            self.loss_function = compute_pairwise_loss
        else:
            self.loss_function = compute_listwise_loss

    def forward(self, tabular_inputs, contexts, labels=None):
        concatenated_inputs = self.concatenate_tabular_inputs(
            tabular_inputs,
            self.config.tabular_features_config
        )
        mlp_outputs = self.mlp(concatenated_inputs)
        logits = self.head(mlp_outputs)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, contexts)

        return loss, logits, contexts

    def concatenate_tabular_inputs(self, inputs, tabular_features_config):
        numerical_inputs = []
        categorical_inputs = []
        for name, config in tabular_features_config.items():
            if config["dtype"] == "categorical":
                feature_name = f"{name}_embedding"
                share_embedding = config.get("share_embedding")
                if share_embedding:
                    feature_name = f"{share_embedding}_embedding"

                embedding = self.embeddings[feature_name]
                features = inputs[name].type(torch.long)
                embed = embedding(features)
                categorical_inputs.append(embed)
            elif config["dtype"] == "numerical":
                features = inputs[name].type(torch.float32)
                if len(features.shape) == 1:
                    features = features.unsqueeze(dim=1)

                numerical_inputs.append(features)

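        # only append numerical features when there are any, otherwise
        # torch.cat below would receive an empty Python list
        if len(numerical_inputs) > 0:
            numerical_inputs = torch.cat(numerical_inputs, dim=-1)
            categorical_inputs.append(numerical_inputs)

        concatenated_inputs = torch.cat(categorical_inputs, dim=-1)
        return concatenated_inputs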

    def init_tabular_parameters(self, tabular_features_config):
        embeddings = {}
        output_dim = 0
        for name, config in tabular_features_config.items():
            if config["dtype"] == "categorical":
                feature_name = f"{name}_embedding"
                # create new embedding layer for categorical features if share_embedding is None
                share_embedding = config.get("share_embedding")
                if share_embedding:
                    # reuse the embedding size of the feature whose embedding is shared
                    share_embedding_config = tabular_features_config[share_embedding]
                    embedding_size = share_embedding_config["embedding_size"]
                else:
                    embedding_size = config["embedding_size"]
                    embedding = nn.Embedding(config["vocab_size"], embedding_size)
                    embeddings[feature_name] = embedding

                output_dim += embedding_size
            elif config["dtype"] == "numerical":
                output_dim += 1

        return nn.ModuleDict(embeddings), output_dim
In [19]:
mlp_config = [
    {
        "hidden_size": 1024,
        "dropout_p": 0.1,
        "activation_function": "ReLU",
        "normalization_function": "LayerNorm"
    },
    {
        "hidden_size": 512,
        "dropout_p": 0.1,
        "activation_function": "ReLU",
        "normalization_function": "LayerNorm"
    },
    {
        "hidden_size": 256,
        "dropout_p": 0.1,
        "activation_function": "ReLU",
        "normalization_function": "LayerNorm"
    }
]
config = TabularModelConfig(tabular_features_config, mlp_config, loss_name="listwise")
model = TabularModel(config).to(device)
print("# of parameters: ", model.num_parameters())
model
# of parameters:  708097
Out[19]:
TabularModel(
  (embeddings): ModuleDict()
  (mlp): Sequential(
    (0): Linear(in_features=46, out_features=1024, bias=True)
    (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=1024, out_features=512, bias=True)
    (5): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (6): ReLU()
    (7): Dropout(p=0.1, inplace=False)
    (8): Linear(in_features=512, out_features=256, bias=True)
    (9): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (10): ReLU()
    (11): Dropout(p=0.1, inplace=False)
  )
  (head): Linear(in_features=256, out_features=1, bias=True)
)
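The reported parameter count can be reproduced by hand from the printed architecture: each `Linear` layer has `in * out` weights plus `out` biases, and each `LayerNorm` with `elementwise_affine=True` has one scale and one shift per feature.

```python
def linear_params(in_features, out_features):
    # weight matrix plus bias vector
    return in_features * out_features + out_features


def layer_norm_params(normalized_size):
    # elementwise affine: one scale and one shift per feature
    return 2 * normalized_size


hidden_sizes = [1024, 512, 256]
total = 0
input_dim = 46  # one input per numerical feature
for hidden_size in hidden_sizes:
    total += linear_params(input_dim, hidden_size) + layer_norm_params(hidden_size)
    input_dim = hidden_size
total += linear_params(input_dim, 1)  # single-score head
print(total)  # 708097
```

Most of the parameters sit in the 1024 to 512 projection, which alone accounts for 524,800 of them.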
In [20]:
def compute_metrics(eval_preds, round_digits: int = 3):
    """Reports NDCG metrics"""
    (y_pred, context), y_true = eval_preds

    y_score = y_pred[:, 0]

    ndcg_metrics = torchmetrics.retrieval.RetrievalNormalizedDCG(top_k=5)
    ndcg = ndcg_metrics(torch.FloatTensor(y_score), torch.FloatTensor(y_true), indexes=torch.LongTensor(context))
    return {
        'ndcg': ndcg
    }
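`compute_metrics` hands a flat array of scores, labels and context ids to torchmetrics, which groups rows by `indexes`, scores each query separately and averages the results. A plain-Python sketch of that per-query NDCG@k, assuming linear gains, a log2 discount, ties broken by input order and a score of 0 for queries without relevant documents; these details are assumptions and may not match torchmetrics exactly.

```python
import math
from collections import defaultdict


def ndcg_at_k(scores, labels, k):
    """NDCG@k for one query with linear gains and log2 discounts."""
    # labels ordered by descending model score (sorted is stable for ties)
    ranked = [label for _, label in sorted(zip(scores, labels), key=lambda pair: -pair[0])]
    ideal = sorted(labels, reverse=True)
    dcg = sum(gain / math.log2(i + 2) for i, gain in enumerate(ranked[:k]))
    idcg = sum(gain / math.log2(i + 2) for i, gain in enumerate(ideal[:k]))
    return dcg / idcg if idcg > 0 else 0.0


def mean_ndcg_by_query(scores, labels, indexes, k=5):
    """Group flat predictions by query id, score each query, then average."""
    groups = defaultdict(lambda: ([], []))
    for score, label, index in zip(scores, labels, indexes):
        groups[index][0].append(score)
        groups[index][1].append(label)
    per_query = [ndcg_at_k(query_scores, query_labels, k)
                 for query_scores, query_labels in groups.values()]
    return sum(per_query) / len(per_query)


# hypothetical predictions for two queries with graded labels
scores = [0.9, 0.1, 0.5, 0.3, 0.8]
labels = [0, 2, 1, 0, 1]
indexes = [0, 0, 0, 1, 1]
print(round(mean_ndcg_by_query(scores, labels, indexes), 4))  # 0.81
```

Every query contributes equally to the average regardless of how many documents it has.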

When training a learning to rank model, an important detail is to disable data shuffling in the data loader so that rows from the same context stay together in a mini-batch. At the time of writing, the Hugging Face Transformers Trainer shuffles the train dataset by default. We override that behaviour by using get_test_dataloader for our train data loader as well. This fixes the issue with the least amount of code; the quirk is that per_device_eval_batch_size is now also used as the training batch size in place of per_device_train_batch_size, which can be confusing.
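A hypothetical helper for sanity-checking a batch, for example on `next(iter(dl))["contexts"].tolist()` as in the `get_train_dataloader` docstring: it reports whether every context id occupies a single contiguous run, which is what the listwise loss relies on.

```python
def contexts_are_contiguous(contexts):
    """True if every context id occupies one contiguous run in the sequence."""
    seen = set()
    previous = object()  # sentinel that never equals a real context id
    for context in contexts:
        if context != previous:
            # a new run starts; it is only valid if this id has not appeared before
            if context in seen:
                return False
            seen.add(context)
            previous = context
    return True


print(contexts_are_contiguous([0, 0, 0, 1, 1, 2]))  # True
print(contexts_are_contiguous([0, 1, 0, 1, 1, 2]))  # False
```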

In [21]:
class TabularRankingTrainer(Trainer):

    def get_train_dataloader(self) -> DataLoader:
        """
        We should confirm context from this data loader isn't shuffled.

        ```
        dl = trainer.get_train_dataloader()
        next(iter(dl))["contexts"]
        ```
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        return super().get_test_dataloader(self.train_dataset)
In [22]:
os.environ["DISABLE_MLFLOW_INTEGRATION"] = "TRUE"
training_args = TrainingArguments(
    output_dir="tabular",
    num_train_epochs=50,
    learning_rate=0.001,
    # not used for batching: get_train_dataloader above reuses the eval batch size
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    gradient_accumulation_steps=2,
    fp16=True,
    lr_scheduler_type="constant",
    evaluation_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,
    do_train=True,
    # our collate function gathers all tabular features into a single
    # tabular_inputs entry; keep the raw feature columns so the huggingface
    # Trainer doesn't drop them (they aren't arguments of model.forward)
    # before they reach the collate function
    remove_unused_columns=False,
    load_best_model_at_end=True
)

trainer = TabularRankingTrainer(
    model,
    args=training_args,
    data_collator=tabular_collate_fn,
    train_dataset=dataset_train,
    eval_dataset=dataset_validation,
    compute_metrics=compute_metrics
)

# on this multi-level graded validation dataset, pairwise/listwise
# loss gives a 0.69 - 0.70 NDCG@5
train_output = trainer.train()
/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
Could not estimate the number of tokens of the input, floating-point operations will not be computed
[8900/8900 08:24, Epoch 49/50]
Step   Training Loss   Validation Loss   Ndcg
500    21.815300       20.822407         0.489811
1000   21.282800       20.460518         0.510252
1500   20.930400       20.099937         0.522047
2000   20.950900       19.581291         0.556249
2500   20.370300       19.352951         0.573809
3000   20.088400       18.906128         0.593706
3500   19.419700       18.469860         0.610612
4000   19.186300       18.217682         0.618409
4500   19.342700       17.893827         0.647029
5000   18.695400       17.563608         0.658074
5500   18.484600       17.467468         0.663601
6000   17.947200       17.103064         0.675647
6500   17.924100       16.729994         0.684831
7000   17.969500       16.560793         0.689212
7500   17.559800       16.929613         0.694637
8000   17.424100       16.342243         0.697590
8500   17.035300       16.118475         0.698370

In computational advertising, particularly for click-through rate (CTR) prediction, pointwise loss functions remain the dominant approach because of:

  • Calibrated Score. For an ad auction to work properly, a model's prediction score needs to be interpretable as a click probability, not merely as a score that denotes ordering or preference.
  • Data Sparsity. Pairwise/listwise approaches rely on events with positive outcomes: their loss functions compare records with positive events against those without. In practice, positive events can be sparse, i.e. there are far fewer instances of user engagement (clicks) than non-engagement, and a request without any click yields no pairs at all. Pairwise or listwise methods therefore discard a significant portion of the available data, which can hurt downstream performance. The pointwise approach doesn't have this limitation and makes better use of the available data.
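Both points can be seen on a toy query. The sketch below, with hypothetical helper names, compares a pointwise log loss against a RankNet-style pairwise logistic loss (the pairwise form is our choice for illustration). Shifting every score by the same constant keeps the ordering, so the pairwise loss doesn't change at all, meaning it never pins down calibrated probabilities; the pointwise loss does change. A query with no clicks contributes nothing to the pairwise loss.

```python
import math


def pointwise_log_loss(scores, labels):
    """Mean binary cross entropy that treats sigmoid(score) as a click probability."""
    total = 0.0
    for s, y in zip(scores, labels):
        # -log(sigmoid(s)) for clicks, -log(1 - sigmoid(s)) for non-clicks
        total += math.log1p(math.exp(-s)) if y == 1 else math.log1p(math.exp(s))
    return total / len(scores)


def pairwise_logistic_loss(scores, labels):
    """RankNet-style loss averaged over every (clicked, non-clicked) pair of one query."""
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    losses = [math.log1p(math.exp(-(p - n))) for p in positives for n in negatives]
    # a query without both a click and a non-click yields no pairs, hence no signal
    return sum(losses) / len(losses) if losses else 0.0


scores = [2.0, -1.0, -2.0]
labels = [1, 0, 0]
shifted = [s + 3.0 for s in scores]  # same ordering, very different probabilities

print(pairwise_logistic_loss(scores, labels), pairwise_logistic_loss(shifted, labels))  # identical
print(pointwise_log_loss(scores, labels), pointwise_log_loss(shifted, labels))  # different
print(pairwise_logistic_loss([0.3, -0.5], [0, 0]))  # no clicks -> 0.0
```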

To preserve the benefits of both the pointwise and pairwise/listwise approaches, a natural strategy is to optimize a weighted combination of the two loss functions [4] [5]. Because pairwise data is sparse, it can help to create pseudo pairs so that the model isn't dominated by the classification loss. For example, we can form additional pairs by grouping impressions from different requests that belong to the same user and session.
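A minimal sketch of such a hybrid objective, assuming a pairwise logistic term and a single weight `alpha`; the function name, weighting scheme and grouping keys are our own illustrative choices, not the exact formulation of [4] or [5]. Passing a (user, session) key instead of a request id as `group_ids` is what turns impressions from different requests into pseudo pairs.

```python
import math
from collections import defaultdict


def hybrid_loss(scores, labels, group_ids, alpha=0.5):
    """alpha * pointwise log loss + (1 - alpha) * pairwise logistic loss.

    Pairs are formed between clicked and non-clicked impressions sharing a
    group id. Returns (loss, number of pairs used).
    """
    # pointwise term: binary cross entropy on logits
    pointwise = sum(
        math.log1p(math.exp(-s)) if y == 1 else math.log1p(math.exp(s))
        for s, y in zip(scores, labels)
    ) / len(scores)

    groups = defaultdict(list)
    for s, y, g in zip(scores, labels, group_ids):
        groups[g].append((s, y))

    # pairwise term: every (clicked, non-clicked) pair within each group
    pair_losses = []
    for members in groups.values():
        positives = [s for s, y in members if y == 1]
        negatives = [s for s, y in members if y == 0]
        pair_losses += [math.log1p(math.exp(-(p - n))) for p in positives for n in negatives]

    pairwise = sum(pair_losses) / len(pair_losses) if pair_losses else 0.0
    return alpha * pointwise + (1 - alpha) * pairwise, len(pair_losses)


scores = [1.5, 0.2, -0.3, 0.4]
labels = [1, 0, 0, 0]
# two requests from the same user session; the second request has no click
request_ids = ["r1", "r1", "r2", "r2"]
session_ids = [("u1", "s1")] * 4  # hypothetical (user, session) grouping key

print(hybrid_loss(scores, labels, request_ids))  # 1 pair
print(hybrid_loss(scores, labels, session_ids))  # 3 pseudo pairs
```

Grouping by session lets the clicked impression from the first request also be compared against the unclicked impressions of the second, tripling the pairwise signal in this toy example while leaving the pointwise term untouched.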

Reference

  • [1] Chris Burges, Tal Shaked, Erin Renshaw, Ari Lazier, Matt Deeds, Nicole Hamilton, Greg Hullender - Learning to Rank using Gradient Descent - 2005
  • [2] Zhe Cao, Tao Qin, Ming-Feng Tsai, et al. - Learning to Rank: From Pairwise Approach to Listwise Approach - 2007
  • [3] Sebastian Bruch, Xuanhui Wang, Michael Bendersky, Marc Najork - An Analysis of the Softmax Cross Entropy Loss for Learning-to-Rank with Binary Relevance - 2019
  • [4] Cheng Li, Yue Lu, Qiaozhu Mei, Dong Wang, Sandeep Pandey - Click-through Prediction for Advertising in Twitter Timeline - 2015
  • [5] Shuguang Han et al. - Joint Optimization of Ranking and Calibration with Contextualized Hybrid Model - 2022
  • [6] Tao Qin, Tie-Yan Liu - Introducing LETOR 4.0 Datasets - 2013
  • [7] LETOR: Learning to Rank for Information Retrieval - LETOR 4.0
  • [8] Microsoft Learning to Rank Datasets
  • [9] Blog: Istella Learning to Rank dataset
  • [10] Blog: Learning to rank is good for your ML career - Part 2: let’s implement ListNet!