In [1]:
# code for loading notebook's format
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))

from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
Out[1]:
In [2]:
os.chdir(path)

%load_ext watermark
%load_ext autoreload
%autoreload 2

import os
import yaml
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import sklearn.metrics as metrics
from torch.nn import functional as F
from torch.utils.data import DataLoader
from datasets import (
    Dataset,
    load_dataset,
    disable_progress_bar
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder, MinMaxScaler
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    TrainingArguments,
    Trainer
)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

%watermark -a 'Ethen' -d -t -v -u -p transformers,datasets,torch,numpy,pandas,sklearn
Author: Ethen

Last updated: 2023-08-23 21:35:27

Python implementation: CPython
Python version       : 3.10.6
IPython version      : 8.13.2

transformers: 4.31.0
datasets    : 2.14.4
torch       : 2.0.1
numpy       : 1.23.2
pandas      : 2.0.1
sklearn     : 1.3.0

Deep Learning for Tabular Data

While deep learning's achievements are often highlighted in areas like computer vision and natural language processing, a lesser-discussed yet potent application involves applying deep learning to tabular data.

A key technique for maximizing deep learning's potential with tabular data is using embeddings for categorical variables [4]. This means representing categories in a lower-dimensional numeric space, capturing intricate relationships between them. For instance, this could reveal geographic connections between high-cardinality categorical features like zip codes, without explicit guidance. Even for features that already come with a numeric representation, such as day of week, it's still worth exploring the potential advantages of treating them as categorical features and utilizing embeddings.
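
As a minimal sketch of this idea (illustrative only, not part of the pipeline below), a day of week field with 7 distinct values can be mapped to a small dense vector via PyTorch's nn.Embedding, with the embedding table learned jointly with the rest of the network:

```
import torch
import torch.nn as nn

# hypothetical day of week feature: 7 categories, each mapped to a 4 dimensional vector
day_of_week_embedding = nn.Embedding(num_embeddings=7, embedding_dim=4)

# a batch of two examples, e.g. Monday (0) and Saturday (5)
day_ids = torch.LongTensor([0, 5])
day_vectors = day_of_week_embedding(day_ids)
print(day_vectors.shape)  # torch.Size([2, 4])
```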

Furthermore, embeddings offer benefits beyond their initial use. Once trained, these embeddings can be employed in other contexts. For example, they can serve as features for tree-based models, granting them the enriched knowledge gleaned from deep learning. This cross-application of embeddings underscores their versatility and their ability to enhance various modeling techniques.
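
As a hedged sketch of this cross-application (the names below are hypothetical), once an embedding has been trained, its vectors can be looked up by category id and appended as extra columns to the feature matrix consumed by a tree-based model:

```
import numpy as np
import torch
import torch.nn as nn

# pretend this embedding was learned as part of a neural network
trained_embedding = nn.Embedding(num_embeddings=1000, embedding_dim=8)

# category ids for each row of some tabular dataset
category_ids = np.array([3, 42, 7, 3])

# look up the learned vectors and convert them into plain numpy columns,
# which can then be concatenated with other features for a tree-based model
with torch.no_grad():
    embedding_features = trained_embedding(torch.from_numpy(category_ids)).numpy()
print(embedding_features.shape)  # (4, 8)
```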

In this article, we'll walk through the bare minimum steps for defining our own deep learning model for tabular data and training it with huggingface's Trainer.

Data Preprocessing

We'll be using a downsampled Criteo dataset, which originated from a Kaggle competition [2]. After the competition ended, the original data files became unavailable on the platform, so we turned to an alternative source for downloading a similar dataset [1]. Each row corresponds to a display ad served by Criteo. Positive (clicked) and negative (non-clicked) examples have both been subsampled at different rates in order to reduce the dataset size. Fields in this dataset include:

  • Label: Target variable that indicates if an ad was clicked (1) or not (0).
  • I1-I13: A total of 13 columns of integer features (mostly count features).
  • C1-C26: A total of 26 columns of categorical features. The values of these features have been hashed onto 32 bits for anonymization purposes.

Unfortunately, the meanings of these features aren't disclosed.

Note that there are many ways to implement the data preprocessing step; the baseline approach we'll take here is to:

  • Encode categorical columns as distinct numerical ids.
  • Standardize/Scale numerical columns.
  • Randomly downsample the negative class for our training set given the imbalanced labels, while keeping the test set's original distribution.
# one time code for creating a sampled criteo dataset
import gzip
import pandas as pd


def parse_criteo_data(gzip_file: str, num_records: int, output_path: str):
    """
    Parse gzipped criteo dataset and save it into a tabular parquet format.
    """
    columns = [
        'label', 'I1', 'I2', 'I3', 'I4', 'I5', 'I6', 'I7', 'I8', 'I9', 'I10',
        'I11', 'I12', 'I13', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8',
        'C9', 'C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16', 'C17', 'C18',
        'C19', 'C20', 'C21', 'C22', 'C23', 'C24', 'C25', 'C26'
    ]

    dtype = {}
    for col in columns:
        if "C" in col:
            dtype[col] = "string"
        elif col == "label":
            dtype[col] = "int"
        else:
            dtype[col] = "float"

    lines = []
    with gzip.open(gzip_file, 'r') as f_in:
        for i in range(num_records):
            line = f_in.readline()
            line = str(line, encoding="utf-8")
            line = line.strip().split("\t")
            lines.append(line)

    df = pd.DataFrame(lines, columns=columns)
    df = df.replace("", None)
    df = df.astype(dtype)
    df.to_parquet(output_path, index=False)


gzip_file = "day_0.gz"
num_records = 1000000
output_path = "criteo_sampled.parquet"
parse_criteo_data(gzip_file, num_records, output_path)
In [3]:
input_path = "criteo_data/criteo_sampled.parquet"
df = pd.read_parquet(input_path)
print(df.shape)
df.head()
(1000000, 40)
Out[3]:
label I1 I2 I3 I4 I5 I6 I7 I8 I9 ... C17 C18 C19 C20 C21 C22 C23 C24 C25 C26
0 1 5.0 110.0 NaN 16.0 NaN 1.0 0.0 14.0 7.0 ... d20856aa b8170bba 9512c20b c38e2f28 14f65a5d 25b1b089 d7c1fc0b 7caf609c 30436bfc ed10571d
1 0 32.0 3.0 5.0 NaN 1.0 0.0 0.0 61.0 5.0 ... d20856aa a1eb1511 9512c20b febfd863 a3323ca1 c8e1ee56 1752e9e8 75350c8a 991321ea b757e957
2 0 NaN 233.0 1.0 146.0 1.0 0.0 0.0 99.0 7.0 ... d20856aa 628f1b8d 9512c20b c38e2f28 14f65a5d 25b1b089 d7c1fc0b 34a9b905 ff654802 ed10571d
3 0 NaN 24.0 NaN 11.0 24.0 NaN 0.0 56.0 3.0 ... 1f7fc70b a1eb1511 9512c20b <NA> <NA> <NA> dc209cd3 b8a81fb0 30436bfc b757e957
4 0 60.0 223.0 6.0 15.0 5.0 0.0 0.0 1.0 8.0 ... d20856aa d9f758ff 9512c20b c709ec07 2b07677e a89a92a5 aa137169 e619743b cdc3217e ed10571d

5 rows × 40 columns

In [4]:
sparse_features = ['C' + str(i) for i in range(1, 27)]
dense_features = ['I' + str(i) for i in range(1, 14)]
feature_names = dense_features + sparse_features

df[sparse_features] = df[sparse_features].fillna('-1')
df[dense_features] = df[dense_features].fillna(0)
In [5]:
# label encoding for categorical/sparse features
# and scaling for numerical/dense features
ordinal_encoder = OrdinalEncoder(min_frequency=30)
df[sparse_features] = ordinal_encoder.fit_transform(df[sparse_features])

min_max_scaler = MinMaxScaler(feature_range=(0, 1))
df[dense_features] = min_max_scaler.fit_transform(df[dense_features])
df.head()
Out[5]:
label I1 I2 I3 I4 I5 I6 I7 I8 I9 ... C17 C18 C19 C20 C21 C22 C23 C24 C25 C26
0 1 0.000076 0.013750 0.000000 0.000039 0.000000 0.001033 0.0 0.000771 0.001964 ... 3.0 138.0 9.0 970.0 161.0 221.0 1229.0 1401.0 2.0 20.0
1 0 0.000488 0.000375 0.016129 0.000000 0.000185 0.000000 0.0 0.003186 0.001403 ... 3.0 117.0 9.0 1251.0 1114.0 1073.0 126.0 1317.0 15.0 15.0
2 0 0.000000 0.029125 0.003226 0.000353 0.000185 0.000000 0.0 0.005139 0.001964 ... 3.0 75.0 9.0 970.0 161.0 221.0 1229.0 610.0 32.0 20.0
3 0 0.000000 0.003000 0.000000 0.000027 0.004431 0.000000 0.0 0.002929 0.000842 ... 1.0 117.0 9.0 0.0 0.0 0.0 1255.0 2059.0 2.0 15.0
4 0 0.000916 0.027875 0.019355 0.000036 0.000923 0.000000 0.0 0.000103 0.002245 ... 3.0 160.0 9.0 1255.0 1711.0 1374.0 948.0 2566.0 19.0 20.0

5 rows × 40 columns

In [6]:
df["label"].value_counts()
Out[6]:
label
0    970960
1     29040
Name: count, dtype: int64
In [7]:
def downsample_negative(df: pd.DataFrame, frac: float = 0.5, random_state: int = 1234):
    """Given a binary classification task with 0/1 labels, downsample negative class (class 0) with the
    specified fraction parameter.
    """
    df_majority = df[df["label"] == 0]
    df_minority = df[df["label"] == 1]
    df_downsampled_majority = df_majority.sample(frac=frac, random_state=random_state)
    df_downsampled = pd.concat([df_downsampled_majority, df_minority])
    # shuffle the combined data frame
    df_downsampled = df_downsampled.sample(frac=1, random_state=random_state).reset_index(drop=True)
    return df_downsampled
In [8]:
df_train, df_test = train_test_split(df, test_size=0.1, random_state=1234, stratify=df["label"])
df_test = df_test.reset_index(drop=True)
df_train_downsampled = downsample_negative(df_train)

dataset_train = Dataset.from_pandas(df_train_downsampled)
dataset_test = Dataset.from_pandas(df_test)
print(dataset_train)
dataset_train[0]
Dataset({
    features: ['label', 'I1', 'I2', 'I3', 'I4', 'I5', 'I6', 'I7', 'I8', 'I9', 'I10', 'I11', 'I12', 'I13', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16', 'C17', 'C18', 'C19', 'C20', 'C21', 'C22', 'C23', 'C24', 'C25', 'C26'],
    num_rows: 463068
})
Out[8]:
{'label': 0,
 'I1': 0.0,
 'I2': 0.005625,
 'I3': 0.0,
 'I4': 0.0,
 'I5': 0.0007385524372230429,
 'I6': 0.0,
 'I7': 0.0,
 'I8': 0.0,
 'I9': 0.0005611672278338945,
 'I10': 0.0,
 'I11': 0.013513513513513514,
 'I12': 0.000576345639540026,
 'I13': 0.0,
 'C1': 888.0,
 'C2': 866.0,
 'C3': 1134.0,
 'C4': 292.0,
 'C5': 379.0,
 'C6': 0.0,
 'C7': 1490.0,
 'C8': 141.0,
 'C9': 2.0,
 'C10': 631.0,
 'C11': 113.0,
 'C12': 1406.0,
 'C13': 4.0,
 'C14': 400.0,
 'C15': 882.0,
 'C16': 11.0,
 'C17': 3.0,
 'C18': 138.0,
 'C19': 11.0,
 'C20': 114.0,
 'C21': 1011.0,
 'C22': 113.0,
 'C23': 1209.0,
 'C24': 835.0,
 'C25': 2.0,
 'C26': 23.0}

We'll specify a config mapping for tabular features that we'll be using across both our batch collate function and our model. This config mapping has the features we wish to leverage as keys, with values specifying whether each field is of numerical or categorical type. This informs our model of the embedding size required for each categorical field, as well as how many numerical fields there are when initializing the dense/feed forward layers.

In [9]:
# demonstrating the functionality with 1 numerical and 1 categorical feature
tabular_features_config = {
    "I1": {
        "dtype": "numerical",
    },
    "C1": {
        "dtype": "categorical",
        "vocab_size": len(ordinal_encoder.categories_[0]),
        "embedding_size": 32
    }
}


def tabular_collate_fn(batch):
    """
    Use in conjunction with Dataloader's collate_fn for tabular data.

    Returns
    -------
    batch : dict
        Dictionary with two primary keys: tabular_inputs, and labels. Tabular
        inputs is a nested field, where each element is a feature_name -> float tensor
        mapping. e.g.
        {
            'tabular_inputs': {'I1': tensor([0., 0.]), 'C1': tensor([ 888., 1313.])},
             'labels': tensor([0, 0])
        }
    """
    labels = []
    tabular_inputs = {}
    for example in batch:
        label = example["label"]
        labels.append(label)

        for name in tabular_features_config:
            feature = example[name]
            if name not in tabular_inputs:
                tabular_inputs[name] = [feature]
            else:
                tabular_inputs[name].append(feature)

    for name in tabular_inputs:
        tabular_inputs[name] = torch.FloatTensor(tabular_inputs[name])

    batch = {
        "tabular_inputs": tabular_inputs,
        "labels": torch.LongTensor(labels)
    }
    return batch
In [10]:
data_loader = DataLoader(dataset_train, batch_size=2, collate_fn=tabular_collate_fn)
batch = next(iter(data_loader))
batch
Out[10]:
{'tabular_inputs': {'I1': tensor([0., 0.]), 'C1': tensor([ 888., 1313.])},
 'labels': tensor([0, 0])}
In [11]:
# specify all the features in features_config.yaml to prevent a clunky display
# we'll need to update the vocabulary size for each categorical feature
# if we were to use a different dataset

# for category in ordinal_encoder.categories_:
#     print(len(category))
with open("features_config.yaml", "r") as f_in:
    config = yaml.safe_load(f_in)

tabular_features_config = config["tabular_features_config"]

data_loader = DataLoader(dataset_train, batch_size=2, collate_fn=tabular_collate_fn)
batch = next(iter(data_loader))
batch
Out[11]:
{'tabular_inputs': {'I1': tensor([0., 0.]),
  'I2': tensor([0.0056, 0.0180]),
  'I3': tensor([0., 0.]),
  'I4': tensor([0.0000, 0.0012]),
  'I5': tensor([0.0007, 0.0002]),
  'I6': tensor([0., 0.]),
  'I7': tensor([0., 0.]),
  'I8': tensor([0.0000, 0.0217]),
  'I9': tensor([0.0006, 0.0017]),
  'I10': tensor([0., 0.]),
  'I11': tensor([0.0135, 0.0045]),
  'I12': tensor([0.0006, 0.0168]),
  'I13': tensor([0., 0.]),
  'C1': tensor([ 888., 1313.]),
  'C2': tensor([866., 276.]),
  'C3': tensor([1134., 3660.]),
  'C4': tensor([292., 213.]),
  'C5': tensor([379., 963.]),
  'C6': tensor([0., 2.]),
  'C7': tensor([1490., 2072.]),
  'C8': tensor([141., 401.]),
  'C9': tensor([2., 4.]),
  'C10': tensor([ 631., 1574.]),
  'C11': tensor([ 113., 1499.]),
  'C12': tensor([1406., 2965.]),
  'C13': tensor([4., 8.]),
  'C14': tensor([400., 198.]),
  'C15': tensor([882.,  36.]),
  'C16': tensor([11., 26.]),
  'C17': tensor([3., 2.]),
  'C18': tensor([138.,  75.]),
  'C19': tensor([11.,  3.]),
  'C20': tensor([ 114., 1255.]),
  'C21': tensor([1011., 1711.]),
  'C22': tensor([ 113., 1374.]),
  'C23': tensor([1209., 1466.]),
  'C24': tensor([835.,  38.]),
  'C25': tensor([ 2., 32.]),
  'C26': tensor([23., 18.])},
 'labels': tensor([0, 0])}
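
The exact contents of features_config.yaml aren't reproduced here to avoid clutter; a rough sketch of how such a file could be generated from the fitted encoder is shown below (the fixed embedding size of 32 is an illustrative choice, the actual config picks different sizes per feature, as seen in the model printout later):

```
# illustrative sketch for generating features_config.yaml, not the exact file used here
features_config = {"tabular_features_config": {}}
for name in dense_features:
    features_config["tabular_features_config"][name] = {"dtype": "numerical"}

for name, categories in zip(sparse_features, ordinal_encoder.categories_):
    features_config["tabular_features_config"][name] = {
        "dtype": "categorical",
        "vocab_size": len(categories),
        "embedding_size": 32  # the actual config picks different sizes per feature
    }

with open("features_config.yaml", "w") as f_out:
    yaml.safe_dump(features_config, f_out)
```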

Model

Our model architecture mainly involves converting categorical features into low dimensional embeddings; these embedding outputs are then concatenated with the rest of the dense features before being fed into subsequent feed forward layers.

In [12]:
def get_mlp_layers(input_dim: int, mlp_config):
    """
    Construct MLP, a.k.a. Feed forward layers based on input config.
    
    Parameters
    ----------
    input_dim : 
        Input dimension for the first layer.

    mlp_config : list of dict with the mlp spec.
        An example is shown below; the only mandatory parameter is hidden_size.
        ```
        [
            {
                "hidden_size": 1024,
                "dropout_p": 0.1,
                "activation_function": "ReLU",
                "activation_function_kwargs": {},
                "normalization_function": "LayerNorm",
                "normalization_function_kwargs": {"eps": 1e-05}
            }
        ]
        ```

    Returns
    -------
    nn.Sequential :
        Sequential layer converted from input mlp_config. If mlp_config
        is None, then this returned value will also be None.

    current_dim :
        Dimension for the last layer.
    """
    if mlp_config is None:
        return None, input_dim

    layers = []
    current_dim = input_dim
    for config in mlp_config:
        hidden_size = config["hidden_size"]
        dropout_p = config.get("dropout_p", 0.0)
        activation_function = config.get("activation_function")
        activation_function_kwargs = config.get("activation_function_kwargs", {})
        normalization_function = config.get("normalization_function")
        normalization_function_kwargs = config.get("normalization_function_kwargs", {})

        linear = nn.Linear(current_dim, hidden_size)
        layers.append(linear)

        if normalization_function:
            normalization = getattr(nn, normalization_function)(hidden_size, **normalization_function_kwargs)
            layers.append(normalization)

        if activation_function:
            activation = getattr(nn, activation_function)(**activation_function_kwargs)
            layers.append(activation)

        dropout = nn.Dropout(p=dropout_p)
        layers.append(dropout)
        current_dim = hidden_size

    return nn.Sequential(*layers), current_dim

The next code blocks define a config and a model class following huggingface transformers' class structure [3]. This allows us to leverage its Trainer class for training and evaluating our models instead of writing custom training loops.

In [13]:
class TabularModelConfig(PretrainedConfig):

    model_type = "tabular"

    def __init__(
        self,
        tabular_features_config=None,
        mlp_config=None,
        num_labels=2,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.tabular_features_config = tabular_features_config
        self.mlp_config = mlp_config
        self.num_labels = num_labels
In [14]:
class TabularModel(PreTrainedModel):

    config_class = TabularModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings, output_dim = self.init_tabular_parameters(config.tabular_features_config)
        self.mlp, output_dim = get_mlp_layers(output_dim, config.mlp_config)
        self.head = nn.Linear(output_dim, config.num_labels)

    def forward(self, tabular_inputs, labels=None):
        concatenated_inputs = self.concatenate_tabular_inputs(
            tabular_inputs,
            self.config.tabular_features_config
        )
        mlp_outputs = self.mlp(concatenated_inputs)
        logits = self.head(mlp_outputs)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        # at the bare minimum, we need to return loss as well as logits
        # for both training and evaluation
        return loss, F.softmax(logits, dim=-1)

    def concatenate_tabular_inputs(self, tabular_inputs, tabular_features_config):
        numerical_inputs = []
        categorical_inputs = []
        for name, config in tabular_features_config.items():
            if config["dtype"] == "categorical":
                feature_name = f"{name}_embedding"
                share_embedding = config.get("share_embedding")
                if share_embedding:
                    feature_name = f"{share_embedding}_embedding"

                embedding = self.embeddings[feature_name]
                features = tabular_inputs[name].type(torch.long)
                embed = embedding(features)
                categorical_inputs.append(embed)
            elif config["dtype"] == "numerical":
                features = tabular_inputs[name].type(torch.float32)
                if len(features.shape) == 1:
                    features = features.unsqueeze(dim=1)

                numerical_inputs.append(features)

        # only append numerical inputs when present, otherwise torch.cat on
        # the raw empty list below would fail
        if len(numerical_inputs) > 0:
            numerical_inputs = torch.cat(numerical_inputs, dim=-1)
            categorical_inputs.append(numerical_inputs)

        concatenated_inputs = torch.cat(categorical_inputs, dim=-1)
        return concatenated_inputs

    def init_tabular_parameters(self, tabular_features_config):
        embeddings = {}
        output_dim = 0
        for name, config in tabular_features_config.items():
            if config["dtype"] == "categorical":
                feature_name = f"{name}_embedding"
                # create new embedding layer for categorical features if share_embedding is None
                share_embedding = config.get("share_embedding")
                if share_embedding:
                    share_embedding_config = tabular_features_config[share_embedding]
                    embedding_size = share_embedding_config["embedding_size"]
                else:
                    embedding_size = config["embedding_size"]
                    embedding = nn.Embedding(config["vocab_size"], embedding_size)
                    embeddings[feature_name] = embedding

                output_dim += embedding_size
            elif config["dtype"] == "numerical":
                output_dim += 1

        return nn.ModuleDict(embeddings), output_dim
In [15]:
mlp_config = [
    {
        "hidden_size": 1024,
        "dropout_p": 0.1,
        "activation_function": "ReLU",
        "normalization_function": "LayerNorm"
    },
    {
        "hidden_size": 512,
        "dropout_p": 0.1,
        "activation_function": "ReLU",
        "normalization_function": "LayerNorm"
    },
    {
        "hidden_size": 256,
        "dropout_p": 0.1,
        "activation_function": "ReLU",
        "normalization_function": "LayerNorm"
    }
]
config = TabularModelConfig(tabular_features_config, mlp_config)
model = TabularModel(config)
print("# of parameters: ", model.num_parameters())
model
# of parameters:  51308904
Out[15]:
TabularModel(
  (embeddings): ModuleDict(
    (C1_embedding): Embedding(179728, 64)
    (C2_embedding): Embedding(12325, 32)
    (C3_embedding): Embedding(11780, 32)
    (C4_embedding): Embedding(4156, 32)
    (C5_embedding): Embedding(10576, 32)
    (C6_embedding): Embedding(3, 3)
    (C7_embedding): Embedding(5850, 32)
    (C8_embedding): Embedding(1139, 32)
    (C9_embedding): Embedding(38, 16)
    (C10_embedding): Embedding(136421, 64)
    (C11_embedding): Embedding(33820, 32)
    (C12_embedding): Embedding(34916, 32)
    (C13_embedding): Embedding(10, 10)
    (C14_embedding): Embedding(1841, 32)
    (C15_embedding): Embedding(5445, 32)
    (C16_embedding): Embedding(56, 16)
    (C17_embedding): Embedding(4, 4)
    (C18_embedding): Embedding(615, 32)
    (C19_embedding): Embedding(14, 14)
    (C20_embedding): Embedding(187780, 64)
    (C21_embedding): Embedding(80020, 32)
    (C22_embedding): Embedding(165496, 64)
    (C23_embedding): Embedding(29741, 9)
    (C24_embedding): Embedding(7693, 32)
    (C25_embedding): Embedding(54, 16)
    (C26_embedding): Embedding(33, 16)
  )
  (mlp): Sequential(
    (0): Linear(in_features=789, out_features=1024, bias=True)
    (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (2): ReLU()
    (3): Dropout(p=0.1, inplace=False)
    (4): Linear(in_features=1024, out_features=512, bias=True)
    (5): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (6): ReLU()
    (7): Dropout(p=0.1, inplace=False)
    (8): Linear(in_features=512, out_features=256, bias=True)
    (9): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (10): ReLU()
    (11): Dropout(p=0.1, inplace=False)
  )
  (head): Linear(in_features=256, out_features=2, bias=True)
)
In [16]:
# a quick test on a sample batch to ensure model forward pass runs
output = model(**batch)
output
Out[16]:
(tensor(0.4340, grad_fn=<NllLossBackward0>),
 tensor([[0.6561, 0.3439],
         [0.6399, 0.3601]], grad_fn=<SoftmaxBackward0>))

The rest of the code blocks define boilerplate for leveraging huggingface transformers' Trainer, as well as a compute_metrics function for calculating standard binary classification metrics.

In [17]:
def compute_metrics(eval_preds, round_digits: int = 3):
    y_pred, y_true = eval_preds
    y_score = y_pred[:, 1]

    log_loss = round(metrics.log_loss(y_true, y_score), round_digits)
    roc_auc = round(metrics.roc_auc_score(y_true, y_score), round_digits)
    pr_auc = round(metrics.average_precision_score(y_true, y_score), round_digits)
    return {
        'roc_auc': roc_auc,
        'pr_auc': pr_auc,
        'log_loss': log_loss
    }
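
A quick standalone sanity check on made up scores (purely illustrative numbers) shows the expected input and output format of this function:

```
fake_probs = np.array([[0.9, 0.1], [0.3, 0.7], [0.6, 0.4], [0.2, 0.8]])
fake_labels = np.array([0, 1, 1, 1])
# returns a dict with roc_auc, pr_auc and log_loss computed on the positive class scores
compute_metrics((fake_probs, fake_labels))
```
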
In [18]:
os.environ["DISABLE_MLFLOW_INTEGRATION"] = "TRUE"
training_args = TrainingArguments(
    output_dir="tabular",
    num_train_epochs=5,
    learning_rate=0.001,
    per_device_train_batch_size=128,
    gradient_accumulation_steps=2,
    fp16=True,
    lr_scheduler_type="constant",
    evaluation_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    do_train=True,
    # we collect all tabular features into a single tabular_inputs entry
    # in our collate function; this flag prevents the huggingface trainer
    # from removing those raw feature columns while processing our dataset
    remove_unused_columns=False
)

trainer = Trainer(
    model,
    args=training_args,
    data_collator=tabular_collate_fn,
    train_dataset=dataset_train,
    eval_dataset=dataset_test,
    compute_metrics=compute_metrics
)

# for this dataset, Roc-AUC typically falls in the range of 0.725 - 0.727
train_output = trainer.train()
/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
Could not estimate the number of tokens of the input, floating-point operations will not be computed
[9045/9045 19:43, Epoch 5/5]
Step Training Loss Validation Loss Loss Roc Auc Pr Auc Runtime Samples Per Second Steps Per Second
1000 0.205400 0.134000 0.133939 0.700000 0.073000 75.662400 1321.660000 165.208000
2000 0.200200 0.146000 0.146196 0.718000 0.079000 75.984300 1316.062000 164.508000
3000 0.193100 0.137000 0.137163 0.723000 0.084000 75.452500 1325.337000 165.667000
4000 0.192400 0.132000 0.131718 0.719000 0.086000 75.892400 1317.654000 164.707000
5000 0.189100 0.131000 0.130899 0.727000 0.085000 74.677300 1339.095000 167.387000
6000 0.184700 0.128000 0.128440 0.720000 0.085000 73.770400 1355.558000 169.445000
7000 0.185800 0.132000 0.131577 0.721000 0.086000 76.291500 1310.761000 163.845000
8000 0.179900 0.135000 0.134822 0.717000 0.082000 74.252700 1346.752000 168.344000
9000 0.178100 0.128000 0.127899 0.720000 0.083000 73.945800 1352.341000 169.043000

In [19]:
# we can also leverage .predict for quickly performing batch prediction on a given
# input dataset
prediction_output = trainer.predict(dataset_test)
prediction_output
Out[19]:
PredictionOutput(predictions=array([[0.99899954, 0.0010005 ],
       [0.99010193, 0.00989806],
       [0.9948001 , 0.00519988],
       ...,
       [0.9937588 , 0.00624126],
       [0.83815056, 0.16184942],
       [0.869991  , 0.13000904]], dtype=float32), label_ids=array([0, 0, 0, ..., 0, 0, 0]), metrics={'test_loss': 0.13327662646770477, 'test_roc_auc': 0.719, 'test_pr_auc': 0.085, 'test_log_loss': 0.133, 'test_runtime': 73.7289, 'test_samples_per_second': 1356.321, 'test_steps_per_second': 169.54})
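
If hard 0/1 predictions are needed on top of these probabilities, we can threshold the positive class score from prediction_output ourselves; a small sketch with an illustrative 0.5 cutoff (in practice the threshold would need tuning given the class imbalance):

```
# hypothetical post-processing of the trainer's batch predictions
y_score = prediction_output.predictions[:, 1]  # positive class probability
y_true = prediction_output.label_ids

threshold = 0.5  # illustrative cutoff, not tuned
y_pred = (y_score >= threshold).astype(int)
print(metrics.precision_score(y_true, y_pred, zero_division=0))
print(metrics.recall_score(y_true, y_pred, zero_division=0))
```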

End Notes

In this post, we walked through a baseline workflow for training deep neural networks on tabular datasets in PyTorch. Many works have cited the success of applying deep neural networks as part of their core recommendation stack, e.g. YouTube Recommendation [5] or Airbnb Search [6] [7]. Apart from making the model bigger/deeper to improve performance, we'll briefly touch on some of their key learnings to conclude this article.

Heterogeneous Signals

Compared to matrix factorization based algorithms in collaborative filtering, it's easier to add a diverse set of signals into the model.

For instance, in the context of YouTube recommendation:

  • Recommendation systems particularly benefit from specialized features that capture historical behavior: a user's previous interactions with the item, how many videos the user has watched from a specific channel, or time since the user last watched a video on a particular topic. Apart from hand-crafted numerical features, we can also include a user's watch or search history as a variable length sequence and map it into a dense embedding representation.
  • In a retrieval + ranking staged system, candidate generation information can be propagated into the ranking phase as features, e.g. which sources nominated a candidate and their assigned scores.
  • Categorical variables' embeddings can be shared, e.g. a single video id embedding can be leveraged across various features (impression video id, last video id watched by the user, seed video id for the recommendation).
  • While popular tree based models are invariant to the scaling of individual features, neural networks are quite sensitive to it, so normalizing continuous features is a must. Normalization can be done via min/max scaling, log-transformation, or standardization (a small sketch follows after this list).
  • Recommendation systems often exhibit some form of bias towards the past, as they are trained on prior data. For YouTube, adding a piece of content's age on the platform allows the model to represent a video's time dependent behavior.
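
A small sketch of those normalization options on a made up skewed count feature (the values below are purely illustrative):

```
import numpy as np

counts = np.array([0.0, 1.0, 3.0, 10.0, 250.0])  # hypothetical skewed count feature

min_max_scaled = (counts - counts.min()) / (counts.max() - counts.min())
log_transformed = np.log1p(counts)
standardized = (counts - counts.mean()) / counts.std()
```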

And for Airbnb search:

  • Domain knowledge proves valuable in feature normalization. e.g. when dealing with geo location represented by latitude and longitude, instead of using the raw coordinates we can calculate the offset from the center of the map displayed to the user. This allows the model to learn distance based global properties rather than the specifics of individual geographies. For learning local geography, a new categorical feature is created by taking the city specified in the query and the level 12 S2 cell for a listing; a hashing function then maps these two values (city and S2 cell) into an integer. For example, given the query "San Francisco" and a listing near the Embarcadero (S2 cell 539058204), hashing {"San Francisco", 539058204} -> 71829521 creates this categorical feature (a small sketch follows after this list).
  • Position bias is also a notable topic in the literature. This bias emerges when historical logs are used for training subsequent models. Introducing position as a feature, while regularizing it with dropout, was proposed as a strategy for mitigating this bias.
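
A hedged sketch of that style of hashed cross feature (the hashing function and number of buckets below are illustrative choices, the paper doesn't prescribe them):

```
import hashlib

NUM_HASH_BUCKETS = 1_000_000  # illustrative table size

def hashed_city_s2_feature(city: str, s2_cell_id: int) -> int:
    """Map a (query city, listing S2 cell) pair to a single categorical id via hashing."""
    # use a stable hash, python's built-in hash() is salted per process
    key = f"{city}_{s2_cell_id}".encode("utf-8")
    digest = hashlib.md5(key).hexdigest()
    return int(digest, 16) % NUM_HASH_BUCKETS

feature_id = hashed_city_s2_feature("San Francisco", 539058204)
```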

Reference

  • [1] Criteo 1TB Click Logs dataset
  • [2] Kaggle Competition - Display Advertising Challenge
  • [3] Transformers Doc - Sharing custom models
  • [4] Blog: An Introduction to Deep Learning for Tabular Data
  • [5] Paul Covington, Jay Adams, Emre Sargin - Deep Neural Networks for YouTube Recommendations (2016)
  • [6] Malay Haldar, Mustafa Abdool, Prashant Ramanathan, Tao Xu, Shulin Yang, Huizhong Duan, Qing Zhang, Nick Barrow-Williams, Bradley C. Turnbull, Brendan M. Collins, Thomas Legrand - Applying Deep Learning To Airbnb Search (2018)
  • [7] Malay Haldar, Mustafa Abdool, Prashant Ramanathan, Tyler Sax, Lanbo Zhang, Aamir Mansawala, Shulin Yang, Bradley Turnbull, Junshuo Liao - Improving Deep Learning For Airbnb Search (2020)