# NOTE(review): this bare `None` expression has no effect; it looks like a
# notebook-export artifact
None
# code for loading notebook's format: temporarily cd into the shared
# notebook_format directory (two levels up) so its `formats` module can be
# imported, apply the custom CSS, then cd back
import os
# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))
# NOTE(review): presumably `formats` is importable because the current working
# directory is on sys.path in an interactive session — confirm
from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
# NOTE(review): if the import or load_style raises, this chdir is skipped and
# the session stays inside notebook_format
os.chdir(path)
# IPython magics (not plain Python): load the watermark extension used for the
# environment printout below, and enable autoreload (mode 2) of imported modules
%load_ext watermark
%load_ext autoreload
%autoreload 2
import os
import yaml
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import sklearn.metrics as metrics
from torch.nn import functional as F
from torch.utils.data import DataLoader
from datasets import (
Dataset,
load_dataset,
disable_progress_bar
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder, MinMaxScaler
from transformers import (
PretrainedConfig,
PreTrainedModel,
TrainingArguments,
Trainer
)
# use CUDA when available, otherwise fall back to CPU
# NOTE(review): `device` is not referenced anywhere else in the visible code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# IPython magic: print author, date, time, Python version and the versions of
# the listed packages for reproducibility
%watermark -a 'Ethen' -d -t -v -u -p transformers,datasets,torch,numpy,pandas,sklearn
While deep learning's achievements are often highlighted in areas like computer vision and natural language processing, a lesser-discussed yet potent application involves applying deep learning to tabular data.
A key technique for maximizing deep learning's potential on tabular data is to use embeddings for categorical variables [4]. This means representing each category as a vector in a lower-dimensional numeric space, which lets the model capture intricate relationships between categories. For instance, embeddings of a high-cardinality categorical feature such as zip code could reveal geographic connections without any explicit guidance. Even for numeric features that take only a handful of discrete values, such as day of the week, it's still worth exploring whether treating them as categorical features and learning embeddings for them brings any advantage.
Furthermore, embeddings offer benefits beyond their initial use. Once trained, these embeddings can be employed in other contexts. For example, they can serve as features for tree-based models, granting them the enriched knowledge gleaned from deep learning. This cross-application of embeddings underscores their versatility and their ability to enhance various modeling techniques.
In this article, we'll look at the bare minimum steps for defining a custom deep learning model and training it using the Hugging Face Trainer.
We'll be using a downsampled Criteo dataset, which originated from a Kaggle competition [2]. After the competition ended, however, the original data files were no longer available on the platform, so we turned to an alternative source to download a similar dataset [1]. Each row corresponds to a display ad served by Criteo. Positive (clicked) and negative (non-clicked) examples have both been subsampled, at different rates, to reduce the dataset size. Fields in this dataset include:
Unfortunately, the meanings of these features aren't disclosed.
Note that there are many ways to implement a data preprocessing step; the baseline approach we'll take here is to:
# one time code for creating a sampled criteo dataset
import gzip
import pandas as pd
def _read_criteo_frame(gzip_file: str, num_records: int) -> pd.DataFrame:
    """
    Read up to ``num_records`` rows of a gzipped, tab-separated criteo file.

    Every line must hold exactly 40 tab-separated fields: the label, 13
    numerical features (I1-I13) and 26 categorical features (C1-C26). Empty
    fields become missing values. Reading stops early at end of file.

    Returns
    -------
    pd.DataFrame
        Columns ``label`` (int), ``I1``-``I13`` (float, NaN when missing) and
        ``C1``-``C26`` (pandas "string" dtype, <NA> when missing).

    Raises
    ------
    ValueError
        If a line does not contain exactly 40 fields.
    """
    from itertools import islice

    columns = (
        ["label"]
        + [f"I{i}" for i in range(1, 14)]
        + [f"C{i}" for i in range(1, 27)]
    )
    dtype = {"label": "int"}
    dtype.update({f"I{i}": "float" for i in range(1, 14)})
    dtype.update({f"C{i}": "string" for i in range(1, 27)})

    rows = []
    with gzip.open(gzip_file, "rt", encoding="utf-8") as f_in:
        # islice simply stops when the file has fewer than num_records lines
        for line_number, raw_line in enumerate(islice(f_in, num_records), start=1):
            # strip only the line terminator: a plain strip() would also remove
            # the trailing tabs of empty last columns and shift the field count
            fields = raw_line.rstrip("\r\n").split("\t")
            if len(fields) != len(columns):
                raise ValueError(
                    f"line {line_number} of {gzip_file} has {len(fields)} fields, "
                    f"expected {len(columns)}"
                )
            # map empty fields to None while parsing, so no post-hoc
            # DataFrame.replace (whose None handling differs across pandas
            # versions) is needed
            rows.append([field if field else None for field in fields])

    return pd.DataFrame(rows, columns=columns).astype(dtype)


def parse_criteo_data(gzip_file: str, num_records: int, output_path: str):
    """
    Parse gzipped criteo dataset and save it into a tabular parquet format.

    The first ``num_records`` rows of ``gzip_file`` (fewer if the file is
    shorter) are written to ``output_path`` without the index. The label is
    stored as int, I1-I13 as float and C1-C26 as strings; empty fields are
    stored as missing values.

    Raises
    ------
    ValueError
        If a line does not contain exactly 40 tab-separated fields.
    """
    df = _read_criteo_frame(gzip_file, num_records)
    df.to_parquet(output_path, index=False)
# one-off run: convert the first million rows of day 0 into a parquet file
gzip_file, num_records, output_path = "day_0.gz", 1_000_000, "criteo_sampled.parquet"
parse_criteo_data(
    gzip_file=gzip_file,
    num_records=num_records,
    output_path=output_path,
)
# NOTE(review): this reads from criteo_data/, while the one-off conversion above
# writes criteo_sampled.parquet into the current directory — presumably the file
# is moved there manually; confirm
input_path = "criteo_data/criteo_sampled.parquet"
df = pd.read_parquet(input_path)
print(df.shape)
# bare expression: only rendered when it is the last line of a notebook cell
df.head()
# C1..C26 are the categorical ("sparse") columns, I1..I13 the numerical ("dense") ones
sparse_features = ['C' + str(i) for i in range(1, 27)]
dense_features = ['I' + str(i) for i in range(1, 14)]
# NOTE(review): feature_names is not used anywhere else in the visible code
feature_names = dense_features + sparse_features
# missing categories become their own '-1' category; missing numbers become 0
df[sparse_features] = df[sparse_features].fillna('-1')
df[dense_features] = df[dense_features].fillna(0)
# label encoding for categorical/sparse features
# and scaling for numerical/dense features
# min_frequency=30 asks OrdinalEncoder to group rarely seen categories into a
# single infrequent category (see scikit-learn docs for the exact rule)
# NOTE(review): both transformers are fit on the full dataset, before the
# train/test split further below, so test-set statistics leak into preprocessing
ordinal_encoder = OrdinalEncoder(min_frequency=30)
df[sparse_features] = ordinal_encoder.fit_transform(df[sparse_features])
min_max_scaler = MinMaxScaler(feature_range=(0, 1))
df[dense_features] = min_max_scaler.fit_transform(df[dense_features])
df.head()
# class balance before any downsampling
df["label"].value_counts()
def downsample_negative(
    df: pd.DataFrame,
    frac: float = 0.5,
    random_state: int = 1234,
    label_col: str = "label",
):
    """Given a binary classification task with 0/1 labels, downsample negative class (class 0) with the
    specified fraction parameter.

    Parameters
    ----------
    df : pd.DataFrame
        Input data; it is not modified.
    frac : float
        Fraction of negative (label 0) rows to keep, passed to ``DataFrame.sample``.
    random_state : int
        Seed used both for sampling the negatives and for the final shuffle.
    label_col : str
        Name of the 0/1 label column (defaults to "label").

    Returns
    -------
    pd.DataFrame
        All positive rows plus the sampled negative rows, shuffled, with a fresh
        0..n-1 index. Rows whose label is neither 0 nor 1 are dropped.
    """
    labels = df[label_col]
    df_majority = df[labels == 0]
    df_minority = df[labels == 1]
    df_downsampled_majority = df_majority.sample(frac=frac, random_state=random_state)
    df_downsampled = pd.concat([df_downsampled_majority, df_minority])
    # shuffle so the two classes are interleaved instead of forming two runs
    return df_downsampled.sample(frac=1, random_state=random_state).reset_index(drop=True)
# 90/10 split, stratified on the label so train and test keep roughly the same
# click rate
df_train, df_test = train_test_split(df, test_size=0.1, random_state=1234, stratify=df["label"])
# NOTE(review): presumably reset so Dataset.from_pandas does not carry the
# shuffled index along as an extra column — confirm
df_test = df_test.reset_index(drop=True)
# only the training split is downsampled; the test split keeps the original
# label distribution (downsample_negative already returns a fresh index)
df_train_downsampled = downsample_negative(df_train)
dataset_train = Dataset.from_pandas(df_train_downsampled)
dataset_test = Dataset.from_pandas(df_test)
print(dataset_train)
dataset_train[0]
We'll specify a config mapping for tabular features that is shared by our batch collate function and our model. Its keys are the features we wish to leverage, and each value specifies whether the field is of numerical or categorical type. This informs the model about the vocabulary and embedding sizes required for each categorical feature, as well as how many numerical fields there are, so it can initialize the dense/feed-forward layers.
# demonstrating the functionality with 1 numerical and 1 categorical feature
# feature name -> spec: "dtype" selects numerical vs categorical handling, and
# categorical specs also give the embedding table's vocab size and vector size
tabular_features_config = {
    "I1": {
        "dtype": "numerical",
    },
    "C1": {
        "dtype": "categorical",
        # categories_[0] belongs to C1, the first column of sparse_features that
        # the encoder was fit on
        # NOTE(review): with min_frequency set, infrequent categories presumably
        # share one code, so this may over-allocate rows — confirm
        "vocab_size": len(ordinal_encoder.categories_[0]),
        "embedding_size": 32
    }
}
def tabular_collate_fn(batch, features_config=None):
    """
    Use in conjunction with Dataloader's collate_fn for tabular data.

    Parameters
    ----------
    batch : list of dict
        Rows to collate; each row must contain "label" and every feature named
        in the features config.
    features_config : dict, optional
        Feature name -> spec mapping; only its keys are used, to select which
        columns to collect. Defaults to the module-level
        ``tabular_features_config``. That name is looked up on every call, so
        rebinding it changes what subsequent calls collect.

    Returns
    -------
    batch : dict
        Dictionary with two primary keys: tabular_inputs, and labels. Tabular
        inputs is a nested field, where each element is a feature_name -> float
        tensor mapping (categorical codes are also carried as floats). e.g.
        {
            'tabular_inputs': {'I1': tensor([0., 0.]), 'C1': tensor([ 888., 1313.])},
            'labels': tensor([0, 0])
        }
        An empty batch yields an empty tensor for every configured feature.
    """
    if features_config is None:
        features_config = tabular_features_config

    tabular_inputs = {
        name: torch.tensor([example[name] for example in batch], dtype=torch.float32)
        for name in features_config
    }
    labels = torch.tensor([example["label"] for example in batch], dtype=torch.long)
    return {
        "tabular_inputs": tabular_inputs,
        "labels": labels
    }
# smoke test: collate the first two training rows using the demo config above
data_loader = DataLoader(dataset_train, batch_size=2, collate_fn=tabular_collate_fn)
batch = next(iter(data_loader))
batch
# specify all the features in features_config.yaml to prevent clunky display
# we'll need to update vocabulary size for each categorical features
# if we were to use a different dataset
# for category in ordinal_encoder.categories_:
#     print(len(category))
with open("features_config.yaml", "r") as f_in:
    config = yaml.safe_load(f_in)
# rebinding this module-level name changes which features tabular_collate_fn
# collects, since the collate function reads tabular_features_config when called
tabular_features_config = config["tabular_features_config"]
data_loader = DataLoader(dataset_train, batch_size=2, collate_fn=tabular_collate_fn)
batch = next(iter(data_loader))
batch
Our model architecture mainly involves converting categorical features into low-dimensional embeddings; these embedding outputs are then concatenated with the rest of the dense features before being fed into subsequent feed-forward layers.
def get_mlp_layers(input_dim: int, mlp_config):
    """
    Construct MLP, a.k.a. feed forward layers, from a list of layer specs.

    Every spec expands, in this order, to ``Linear`` -> optional normalization
    -> optional activation -> ``Dropout``.

    Parameters
    ----------
    input_dim : int
        Input dimension for the first layer.
    mlp_config : list of dict or None
        Layer specs; "hidden_size" is the only mandatory key, e.g.

        [
            {
                "hidden_size": 1024,
                "dropout_p": 0.1,
                "activation_function": "ReLU",
                "activation_function_kwargs": {},
                "normalization_function": "LayerNorm",
                "normalization_function_kwargs": {"eps": 1e-05}
            }
        ]

        Activation and normalization names are looked up as attributes of
        ``torch.nn``. The normalization class receives ``hidden_size`` as its
        first positional argument. "dropout_p" defaults to 0.0, so a Dropout
        module is always appended.

    Returns
    -------
    nn.Sequential or None
        The stacked layers, or None when mlp_config is None.
    int
        Output dimension of the last layer (``input_dim`` if nothing was built).
    """
    if mlp_config is None:
        return None, input_dim

    modules = []
    in_features = input_dim
    for layer_spec in mlp_config:
        out_features = layer_spec["hidden_size"]
        modules.append(nn.Linear(in_features, out_features))

        norm_name = layer_spec.get("normalization_function")
        if norm_name:
            norm_kwargs = layer_spec.get("normalization_function_kwargs", {})
            modules.append(getattr(nn, norm_name)(out_features, **norm_kwargs))

        act_name = layer_spec.get("activation_function")
        if act_name:
            act_kwargs = layer_spec.get("activation_function_kwargs", {})
            modules.append(getattr(nn, act_name)(**act_kwargs))

        modules.append(nn.Dropout(p=layer_spec.get("dropout_p", 0.0)))
        in_features = out_features

    return nn.Sequential(*modules), in_features
The next code block defines a config class and a model class following the class structure of Hugging Face Transformers [3]. This allows us to leverage its Trainer class for training and evaluating our model instead of writing a custom training loop.
class TabularModelConfig(PretrainedConfig):
    """
    Configuration object for ``TabularModel``.

    Parameters
    ----------
    tabular_features_config : dict, optional
        Feature name -> spec mapping. Each spec has a "dtype" of "numerical" or
        "categorical"; categorical specs additionally carry "vocab_size",
        "embedding_size" and optionally "share_embedding".
    mlp_config : list of dict, optional
        Hidden-layer specs in the format accepted by ``get_mlp_layers``.
    num_labels : int
        Number of output classes of the classification head.
    **kwargs
        Passed through unchanged to ``PretrainedConfig``.
    """
    model_type = "tabular"

    def __init__(self, tabular_features_config=None, mlp_config=None, num_labels=2, **kwargs):
        # the base class is initialized first; the attributes below are set
        # afterwards so these values are the ones that end up on the config
        super().__init__(**kwargs)
        self.num_labels = num_labels
        self.mlp_config = mlp_config
        self.tabular_features_config = tabular_features_config
class TabularModel(PreTrainedModel):
    """
    Feed-forward classifier over tabular features.

    Categorical features are looked up in embedding tables (their own, or the
    one named by "share_embedding"). The embeddings are concatenated with the
    numerical features and passed through the optional MLP built by
    ``get_mlp_layers``. A final linear head then produces
    ``config.num_labels`` outputs.
    """
    config_class = TabularModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embeddings, output_dim = self.init_tabular_parameters(config.tabular_features_config)
        # self.mlp is None when config.mlp_config is None; forward() then feeds
        # the concatenated features straight into the head
        self.mlp, output_dim = get_mlp_layers(output_dim, config.mlp_config)
        self.head = nn.Linear(output_dim, config.num_labels)

    def forward(self, tabular_inputs, labels=None):
        """
        Parameters
        ----------
        tabular_inputs : dict
            Feature name -> tensor mapping, e.g. as produced by ``tabular_collate_fn``.
        labels : torch.Tensor, optional
            Class indices; when given, a cross-entropy loss is computed.

        Returns
        -------
        tuple
            ``(loss, probabilities)``: loss is None when labels is None, and
            probabilities is the softmax over the head's outputs (last dim).
        """
        hidden = self.concatenate_tabular_inputs(
            tabular_inputs,
            self.config.tabular_features_config
        )
        if self.mlp is not None:
            hidden = self.mlp(hidden)
        logits = self.head(hidden)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        # at the bare minimum, we need to return the loss (first element) plus
        # the model outputs for both training and evaluation. Softmax
        # probabilities rather than raw logits are returned, so column 1 can be
        # used directly as the positive-class score.
        return loss, F.softmax(logits, dim=-1)

    def concatenate_tabular_inputs(self, tabular_inputs, tabular_features_config):
        """
        Build the model input by concatenating, along the last dim, every
        categorical feature's embedding (in config order) followed by all
        numerical features.
        """
        numerical_inputs = []
        categorical_inputs = []
        for name, config in tabular_features_config.items():
            if config["dtype"] == "categorical":
                # a feature with share_embedding looks up the owner's table
                owner = config.get("share_embedding") or name
                embedding = self.embeddings[f"{owner}_embedding"]
                features = tabular_inputs[name].type(torch.long)
                categorical_inputs.append(embedding(features))
            elif config["dtype"] == "numerical":
                features = tabular_inputs[name].type(torch.float32)
                # a 1-D tensor becomes a single column so it can be concatenated
                if len(features.shape) == 1:
                    features = features.unsqueeze(dim=1)
                numerical_inputs.append(features)

        if len(numerical_inputs) > 0:
            categorical_inputs.append(torch.cat(numerical_inputs, dim=-1))
        return torch.cat(categorical_inputs, dim=-1)

    def init_tabular_parameters(self, tabular_features_config):
        """
        Create embedding tables for categorical features and compute the width
        of the concatenated model input.

        A categorical feature with "share_embedding" set gets no table of its
        own. ``concatenate_tabular_inputs`` looks up the named feature's table
        for it instead, so a separate table would only be dead parameters. Its
        values therefore index into the owner's table and must be smaller than
        the owner's "vocab_size".

        Returns
        -------
        nn.ModuleDict
            "<feature>_embedding" -> nn.Embedding for every feature owning a table.
        int
            Sum of the embedding sizes of all categorical features plus one per
            numerical feature.
        """
        embeddings = {}
        output_dim = 0
        for name, config in tabular_features_config.items():
            if config["dtype"] == "categorical":
                share_embedding = config.get("share_embedding")
                if share_embedding:
                    # reuse the owner's table; only its width is needed here
                    output_dim += tabular_features_config[share_embedding]["embedding_size"]
                else:
                    embedding_size = config["embedding_size"]
                    embeddings[f"{name}_embedding"] = nn.Embedding(config["vocab_size"], embedding_size)
                    output_dim += embedding_size
            elif config["dtype"] == "numerical":
                output_dim += 1
        return nn.ModuleDict(embeddings), output_dim
# three hidden layers (1024 -> 512 -> 256); get_mlp_layers expands each spec to
# Linear -> LayerNorm -> ReLU -> Dropout(0.1)
mlp_config = [
    {
        "hidden_size": 1024,
        "dropout_p": 0.1,
        "activation_function": "ReLU",
        "normalization_function": "LayerNorm"
    },
    {
        "hidden_size": 512,
        "dropout_p": 0.1,
        "activation_function": "ReLU",
        "normalization_function": "LayerNorm"
    },
    {
        "hidden_size": 256,
        "dropout_p": 0.1,
        "activation_function": "ReLU",
        "normalization_function": "LayerNorm"
    }
]
# positional args map to tabular_features_config and mlp_config; num_labels keeps its default of 2
config = TabularModelConfig(tabular_features_config, mlp_config)
model = TabularModel(config)
print("# of parameters: ", model.num_parameters())
model
# a quick test on a sample batch to ensure model forward pass runs
# (batch's keys, tabular_inputs and labels, match forward's parameters; the
# output is a (loss, probabilities) tuple)
output = model(**batch)
output
The rest of the code defines boilerplate for leveraging the Hugging Face Transformers Trainer, as well as a compute_metrics
function for calculating standard binary classification metrics.
def compute_metrics(eval_preds, round_digits: int = 3):
    """
    Binary-classification metrics to pass as the Trainer's ``compute_metrics``.

    Parameters
    ----------
    eval_preds : (predictions, label_ids) pair
        predictions is presumably an (n_samples, 2) array of class
        probabilities, with column 1 being the positive class; label_ids holds
        the 0/1 labels.
    round_digits : int
        Number of decimals every metric is rounded to.

    Returns
    -------
    dict
        'roc_auc', 'pr_auc' (average precision) and 'log_loss', in that order.
    """
    predictions, label_ids = eval_preds
    positive_scores = predictions[:, 1]

    # log loss is evaluated first, then ROC-AUC, then PR-AUC
    scorers = (
        ('log_loss', metrics.log_loss),
        ('roc_auc', metrics.roc_auc_score),
        ('pr_auc', metrics.average_precision_score),
    )
    rounded = {
        metric_name: round(scorer(label_ids, positive_scores), round_digits)
        for metric_name, scorer in scorers
    }
    return {key: rounded[key] for key in ('roc_auc', 'pr_auc', 'log_loss')}
# keep the MLflow integration from being picked up during training
os.environ["DISABLE_MLFLOW_INTEGRATION"] = "TRUE"
training_args = TrainingArguments(
    output_dir="tabular",
    num_train_epochs=5,
    learning_rate=0.001,
    per_device_train_batch_size=128,
    # evaluation runs forward passes only, so a much larger batch than the
    # default is affordable. It makes the periodic evaluation on the ~100k-row
    # test split far quicker.
    per_device_eval_batch_size=1024,
    gradient_accumulation_steps=2,
    # fp16 mixed precision needs a GPU. Fall back to fp32 when CUDA is
    # unavailable instead of failing TrainingArguments validation, consistent
    # with the CPU fallback used for `device`.
    fp16=torch.cuda.is_available(),
    lr_scheduler_type="constant",
    # NOTE(review): newer transformers releases rename this argument to
    # `eval_strategy`; keep it in sync with the installed version
    evaluation_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    do_train=True,
    # we are collecting all tabular features into a single entry
    # tabular_inputs during collate function, this is to prevent
    # huggingface trainer from removing these features while processing
    # our dataset
    remove_unused_columns=False
)
trainer = Trainer(
    model,
    args=training_args,
    # groups every configured feature under "tabular_inputs", matching
    # TabularModel.forward's signature
    data_collator=tabular_collate_fn,
    train_dataset=dataset_train,
    eval_dataset=dataset_test,
    compute_metrics=compute_metrics
)
# for this dataset, Roc-AUC typically falls in the range of 0.725 - 0.727
train_output = trainer.train()
# we can also leverage .predict for quickly performing batch prediction on a given
# input dataset
prediction_output = trainer.predict(dataset_test)
prediction_output
In this post, we walked through a baseline workflow for training deep neural networks on tabular datasets in PyTorch. Many works have cited the success of applying deep neural networks as part of their core recommendation stack, e.g. YouTube Recommendation [5] or Airbnb Search [6] [7]. To conclude this article, we'll briefly touch upon some of their key learnings beyond simply making the model bigger or deeper to improve performance.
Heterogeneous Signals
Compared to matrix factorization-based algorithms in collaborative filtering, it's easier to add a diverse set of signals to the model.
For instance, in the context of Youtube recommendation:
Similarly, for Airbnb search: