import os
# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))
from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
# 1. magic for inline plot
# 2. magic to print version
# 3. magic so that the notebook will reload external python modules
# 4. magic to enable retina (high resolution) plots
# https://gist.github.com/minrk/3301035
%matplotlib inline
%load_ext watermark
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'
import os
import time
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import trange
from torch import optim
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
%watermark -a 'Ethen' -d -u -p datasets,transformers,torch,tokenizers,numpy,pandas,matplotlib
In this article, we'll be going over two main things: fine-tuning a pre-trained model on a downstream text classification task (detecting duplicate Quora question pairs), and examining whether the fine-tuned model's predicted scores are well calibrated, using temperature scaling to improve them.
Fine-tuning pre-trained models on downstream tasks has become increasingly popular these days, and this notebook documents some findings on these models' calibration. Calibration in this context means: does the model's predicted score reflect the true probability? If the reader is not familiar with model calibration 101, there is a separate notebook [nbviewer][html] that covers this topic. Reading up till the "Measuring Calibration" section should suffice.
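As a quick intuition for what "well calibrated" means: if we group predictions by their predicted score, the average score within each group should roughly match the fraction of positives in that group. A minimal, purely illustrative sketch with made-up scores and labels (not part of this notebook's pipeline):
# made-up predicted scores and true labels, purely for illustration
scores = np.array([0.1, 0.2, 0.25, 0.7, 0.8, 0.85, 0.9, 0.95])
labels = np.array([0, 0, 1, 1, 0, 1, 1, 1])

# bin predictions by score, then compare each bin's average score against its
# empirical positive rate; for a well-calibrated model the two should roughly agree
bins = np.linspace(0.0, 1.0, 5)
bin_ids = np.digitize(scores, bins[1:-1])
for b in np.unique(bin_ids):
    mask = bin_ids == b
    print(f'avg score: {scores[mask].mean():.2f}, positive rate: {labels[mask].mean():.2f}')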
dataset_dict = load_dataset("quora")
dataset_dict
dataset_dict['train'][0]
test_size = 0.1
val_size = 0.1
dataset_dict_test = dataset_dict['train'].train_test_split(test_size=test_size)
dataset_dict_train_val = dataset_dict_test['train'].train_test_split(test_size=val_size)
dataset_dict = DatasetDict({
    "train": dataset_dict_train_val["train"],
    "val": dataset_dict_train_val["test"],
    "test": dataset_dict_test["test"]
})
dataset_dict
We won't be going over the details of the pre-trained tokenizer or model here; we'll simply load a pre-trained one available from the Hugging Face model repository.
# https://huggingface.co/transformers/model_doc/mobilebert.html
pretrained_model_name_or_path = "google/mobilebert-uncased"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
tokenizer
We can feed our tokenizer directly with a pair of sentences.
encoded_input = tokenizer(
    'What is the step by step guide to invest in share market in india?',
    'What is the step by step guide to invest in share market?'
)
encoded_input
Decoding the tokenized inputs, we can see this model's tokenizer adds some special tokens such as [SEP], which is used to indicate which token belongs to which segment/pair.
tokenizer.decode(encoded_input["input_ids"])
The preprocessing step will be task specific; if we happen to be using another dataset, this function needs to be modified accordingly.
def tokenize_fn(examples):
    labels = [int(label) for label in examples['is_duplicate']]
    texts = [question['text'] for question in examples['questions']]
    texts1 = [text[0] for text in texts]
    texts2 = [text[1] for text in texts]

    tokenized_examples = tokenizer(texts1, texts2)
    tokenized_examples['labels'] = labels
    return tokenized_examples
dataset_dict_tokenized = dataset_dict.map(
    tokenize_fn,
    batched=True,
    num_proc=8,
    remove_columns=['is_duplicate', 'questions']
)
dataset_dict_tokenized
dataset_dict_tokenized['train'][0]
Having preprocessed our raw dataset, for our text classification task we use the AutoModelForSequenceClassification class to load the pre-trained model; the only other argument we need to specify is the number of classes/labels our text classification task has. Upon instantiating this model for the first time, we'll see some warnings telling us we should fine-tune this model on our downstream task before using it.
model_checkpoint = 'text_classification'
num_labels = 2
# we'll save the model after fine tuning it once, so we can skip the fine tuning part during
# the second round if we detect that we already have one available
if os.path.isdir(model_checkpoint):
    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint).to(device)
else:
    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path,
        num_labels=num_labels
    ).to(device)
print('# of parameters: ', model.num_parameters())
data_collator = DataCollatorWithPadding(tokenizer, padding=True)
data_collator
We can perform all sorts of hyperparameter tuning during the fine-tuning step; here we'll pick some default parameters for illustration purposes.
batch_size = 64
args = TrainingArguments(
    "quora_fine_tuned",
    learning_rate=1e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True
)
trainer = Trainer(
    model,
    args,
    data_collator=data_collator,
    train_dataset=dataset_dict_tokenized["train"],
    eval_dataset=dataset_dict_tokenized['val']
)
if not os.path.isdir(model_checkpoint):
    trainer.train()
    model.save_pretrained(model_checkpoint)
class SoftmaxModule(nn.Module):
    """
    Add a softmax layer on top of the base model. Note this does not necessarily
    mean the output score is a well-calibrated probability.
    """

    def __init__(self, model_path: str):
        super().__init__()
        self.model_path = model_path
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        # probability of the positive class (index 1)
        scores = nn.functional.softmax(outputs.logits, dim=-1)[:, 1]
        return scores
softmax_module = SoftmaxModule(model_checkpoint).to(device)
softmax_module.eval()
print('# of parameters: ', sum(p.numel() for p in softmax_module.parameters() if p.requires_grad))
We define some helper functions to generate predictions for our dataset and store the predicted scores and labels in a pandas DataFrame.
def predict(model, examples, round_digits: int = 5):
    input_ids = examples['input_ids'].to(device)
    attention_mask = examples['attention_mask'].to(device)
    token_type_ids = examples['token_type_ids'].to(device)
    batch_labels = examples['labels'].detach().cpu().numpy().tolist()

    model.eval()
    with torch.no_grad():
        batch_output = model(input_ids, attention_mask, token_type_ids)

    batch_scores = np.round(batch_output.detach().cpu().numpy(), round_digits).tolist()
    return batch_scores, batch_labels
def predict_data_loader(model, data_loader: DataLoader) -> pd.DataFrame:
    scores = []
    labels = []
    for examples in data_loader:
        batch_scores, batch_labels = predict(model, examples)
        scores += batch_scores
        labels += batch_labels

    df_predictions = pd.DataFrame.from_dict({'scores': scores, 'labels': labels})
    return df_predictions
data_collator = DataCollatorWithPadding(tokenizer, padding=True)
data_loader = DataLoader(dataset_dict_tokenized['test'], collate_fn=data_collator, batch_size=128)
start = time.time()
df_predictions = predict_data_loader(softmax_module, data_loader)
end = time.time()
print('elapsed: ', end - start)
print(df_predictions.shape)
df_predictions.head()
Temperature Scaling is a post-processing technique proposed to reduce calibration error, designed specifically for deep learning models. It works by dividing the logits (the output of the layer right before the final softmax layer) by a learned scalar parameter.
\begin{align}
\text{softmax}_i = \frac{e^{z_i / T}}{\sum_j e^{z_j / T}}
\end{align}

where $z$ is the logit vector and $T$ is the learned temperature scaling parameter. We learn this parameter on a validation set, where $T$ is chosen to minimize negative log likelihood. As we can imagine, with $T > 1$ the softmax distribution becomes softer, making the model less confident in its predictions, while leaving the model's predicted maximum class unchanged.
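To make this concrete, here's a minimal sketch (with made-up logits rather than outputs from our fine-tuned model) showing that dividing logits by a temperature greater than one softens the softmax distribution without changing the argmax:
# made-up logits for a single 3-class example, purely for illustration
logits = torch.tensor([[2.0, 0.5, -1.0]])
temperature = 2.0

# original softmax vs. temperature scaled softmax: the scaled version is
# less peaked (less confident), but the highest-scoring class stays the same
print(nn.functional.softmax(logits, dim=-1))
print(nn.functional.softmax(logits / temperature, dim=-1))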
The benefit of this approach is mainly twofold: it introduces only a single additional parameter, which can be fit cheaply on a validation set while the base model stays frozen, and since dividing the logits by a positive constant preserves their ranking, the model's predicted class and hence its accuracy are left untouched.
class TemperatureScalingCalibrationModule(nn.Module):

    def __init__(self, model_path: str):
        super().__init__()
        self.model_path = model_path
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)

        # the single temperature scaling parameter, the initialization value doesn't
        # seem to matter that much based on some ad-hoc experimentation
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, input_ids, attention_mask, token_type_ids):
        """forward method that returns softmax-ed confidence scores."""
        outputs = self.forward_logit(input_ids, attention_mask, token_type_ids)
        scores = nn.functional.softmax(outputs, dim=-1)[:, 1]
        return scores

    def forward_logit(self, input_ids, attention_mask, token_type_ids):
        """forward method that returns logits, to be used with cross entropy loss."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        ).logits
        return outputs / self.temperature

    def fit(self, dataset_tokenized, n_epochs: int = 3, batch_size: int = 64, lr: float = 0.01):
        """fits the temperature scaling parameter."""
        data_collator = DataCollatorWithPadding(tokenizer, padding=True)
        data_loader = DataLoader(dataset_tokenized, collate_fn=data_collator, batch_size=batch_size)

        self.freeze_base_model()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(self.parameters(), lr=lr)
        for epoch in trange(n_epochs):
            for examples in data_loader:
                labels = examples['labels'].long().to(device)
                input_ids = examples['input_ids'].to(device)
                attention_mask = examples['attention_mask'].to(device)
                token_type_ids = examples['token_type_ids'].to(device)

                # standard forward and backward pass
                self.zero_grad()
                logits = self.forward_logit(input_ids, attention_mask, token_type_ids)
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()

        return self

    def freeze_base_model(self):
        """remember to freeze base model's parameters when training temperature scaler"""
        self.model.eval()
        for parameter in self.model.parameters():
            parameter.requires_grad = False

        return self
calibration_module = TemperatureScalingCalibrationModule(model_checkpoint).to(device)
calibration_module.fit(dataset_dict_tokenized['val'])
print('# of parameters: ', sum(p.numel() for p in calibration_module.parameters() if p.requires_grad))
calibration_module.temperature
data_collator = DataCollatorWithPadding(tokenizer, padding=True)
data_loader = DataLoader(dataset_dict_tokenized['test'], collate_fn=data_collator, batch_size=128)
start = time.time()
df_calibrated_predictions = predict_data_loader(calibration_module, data_loader)
end = time.time()
print('elapsed: ', end - start)
print(df_calibrated_predictions.shape)
df_calibrated_predictions.head()
Observations:
from calibration_module.utils import compute_calibration_summary
eval_dict = {
    f'{model_checkpoint}': df_predictions,
    f'{model_checkpoint}_calibrated': df_calibrated_predictions
}
# change default style figure and font size
plt.rcParams['figure.figsize'] = 12, 12
plt.rcParams['font.size'] = 12
n_bins = 20
df_result = compute_calibration_summary(eval_dict, label_col='labels', score_col='scores', n_bins=n_bins)
df_result
There are other works [3] that study the calibration of state-of-the-art models. Although they focus mainly on image-based models, their claim is that model size and amount of pretraining don't fully account for the differences in calibration across models; the primary factor seems to be the model architecture. More explicitly, models that rely on attention-based mechanisms are found to be better calibrated than convolution-based ones.