deeplearning_prob_calibration
In [1]:
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))

from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
Out[1]:
In [2]:
# 1. magic for inline plot
# 2. magic to print version
# 3. magic so that the notebook will reload external python modules
# 4. magic to enable retina (high resolution) plots
# https://gist.github.com/minrk/3301035
%matplotlib inline
%load_ext watermark
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'

import os
import time
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import trange
from torch import optim
from torch.utils.data import DataLoader
from datasets import load_dataset, DatasetDict, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

%watermark -a 'Ethen' -d -u -p datasets,transformers,torch,tokenizers,numpy,pandas,matplotlib
Author: Ethen

Last updated: 2022-08-25

datasets    : 2.2.2
transformers: 4.19.2
torch       : 1.11.0
tokenizers  : 0.12.1
numpy       : 1.21.6
pandas      : 1.3.5
matplotlib  : 3.4.3

Deep Learning Model Calibration with Temperature Scaling

In this article, we'll be going over two main things:

  • Fine-tuning a pre-trained BERT model on a text classification task, more specifically, the Quora Question Pairs challenge.
  • Evaluating model calibration and reducing calibration error using temperature scaling [2].

Fine-tuning pre-trained models on downstream tasks has become increasingly popular, and this notebook documents our findings on the calibration of such models. Calibration in this context means whether the model's predicted scores reflect true probabilities. If the reader is not familiar with the basics of model calibration, there is a separate notebook [nbviewer][html] that covers this topic; reading up to the "Measuring Calibration" section should suffice.
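Concretely, a binary classifier is well calibrated if, among the examples it scores around 0.8, roughly 80% are actually positive. A minimal sketch of that check, assuming binary labels and equal-width score bins (`reliability_table` is a hypothetical helper, not part of any library used in this notebook), applied to synthetic scores that are calibrated by construction:

```python
import numpy as np


def reliability_table(scores, labels, n_bins=10):
    """Group predictions into equal-width score bins and compare, per bin,
    the mean predicted score with the observed fraction of positive labels."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=float)
    # bin index for each score; a score of exactly 1.0 is placed in the last bin
    bin_ids = np.minimum((scores * n_bins).astype(int), n_bins - 1)
    rows = []
    for b in range(n_bins):
        mask = bin_ids == b
        if not mask.any():
            continue  # skip empty bins
        rows.append((b, int(mask.sum()), scores[mask].mean(), labels[mask].mean()))
    return rows  # list of (bin index, count, mean predicted score, observed positive rate)


# synthetic, perfectly calibrated "model": each label is positive with probability equal to its score
rng = np.random.default_rng(0)
scores = rng.uniform(size=10_000)
labels = rng.uniform(size=10_000) < scores
for b, count, mean_score, positive_rate in reliability_table(scores, labels, n_bins=5):
    print(f"bin {b}: n={count}, mean score={mean_score:.3f}, positive rate={positive_rate:.3f}")
```

For a calibrated model the last two columns track each other in every bin; an overconfident model shows observed rates that are less extreme than its mean scores.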

In [3]:
dataset_dict = load_dataset("quora")
dataset_dict
Using custom data configuration default
Reusing dataset quora (/home/mingyuliu/.cache/huggingface/datasets/quora/default/0.0.0/36ba4cd42107f051a158016f1bea6ae3f4685c5df843529108a54e42d86c1e04)
Out[3]:
DatasetDict({
    train: Dataset({
        features: ['questions', 'is_duplicate'],
        num_rows: 404290
    })
})
In [4]:
dataset_dict['train'][0]
Out[4]:
{'questions': {'id': [1, 2],
  'text': ['What is the step by step guide to invest in share market in india?',
   'What is the step by step guide to invest in share market?']},
 'is_duplicate': False}
In [5]:
test_size = 0.1
val_size = 0.1
dataset_dict_test = dataset_dict['train'].train_test_split(test_size=test_size)
dataset_dict_train_val = dataset_dict_test['train'].train_test_split(test_size=val_size)

dataset_dict = DatasetDict({
    "train": dataset_dict_train_val["train"],
    "val": dataset_dict_train_val["test"],
    "test": dataset_dict_test["test"]
})
dataset_dict
Out[5]:
DatasetDict({
    train: Dataset({
        features: ['is_duplicate', 'questions'],
        num_rows: 327474
    })
    val: Dataset({
        features: ['is_duplicate', 'questions'],
        num_rows: 36387
    })
    test: Dataset({
        features: ['is_duplicate', 'questions'],
        num_rows: 40429
    })
})
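Because the validation set is carved out of the remaining 90% rather than the full dataset, it ends up being about 9% of the original rows. The row counts above can be reproduced with a small sketch, assuming `train_test_split` with a float `test_size` rounds the test split up (`ceil`) and assigns the remainder to train; `split_sizes` is a hypothetical helper:

```python
import math


def split_sizes(n_rows, test_size=0.1, val_size=0.1):
    """Return (train, val, test) row counts for the two-stage split above."""
    n_test = math.ceil(test_size * n_rows)      # first split: hold out the test set
    n_remaining = n_rows - n_test
    n_val = math.ceil(val_size * n_remaining)   # second split: hold out val from what is left
    n_train = n_remaining - n_val
    return n_train, n_val, n_test


n_train, n_val, n_test = split_sizes(404290)
print(n_train, n_val, n_test)     # 327474 36387 40429
print(round(n_val / 404290, 3))   # 0.09
```

Note that neither `train_test_split` call passes a `seed`, so the exact rows in each split will differ between runs even though the sizes stay the same.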

Tokenizer

We won't go over the details of the pre-trained tokenizer or model; we simply load a pre-trained one available from the Hugging Face model repository.

In [6]:
# https://huggingface.co/transformers/model_doc/mobilebert.html
pretrained_model_name_or_path = "google/mobilebert-uncased"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
tokenizer
Out[6]:
PreTrainedTokenizerFast(name_or_path='google/mobilebert-uncased', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

We can pass a pair of sentences to the tokenizer directly.

In [7]:
encoded_input = tokenizer(
    'What is the step by step guide to invest in share market in india?',
    'What is the step by step guide to invest in share market?'
)
encoded_input
Out[7]:
{'input_ids': [101, 2054, 2003, 1996, 3357, 2011, 3357, 5009, 2000, 15697, 1999, 3745, 3006, 1999, 2634, 1029, 102, 2054, 2003, 1996, 3357, 2011, 3357, 5009, 2000, 15697, 1999, 3745, 3006, 1029, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

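Decoding the tokenized input shows that this model's tokenizer adds special tokens: [CLS] at the beginning, and [SEP] after each sentence to mark which tokens belong to which segment of the pair. The token_type_ids field carries the same segment information, with 0 for the first sentence and 1 for the second.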

In [8]:
tokenizer.decode(encoded_input["input_ids"])
Out[8]:
'[CLS] what is the step by step guide to invest in share market in india? [SEP] what is the step by step guide to invest in share market? [SEP]'
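The layout of a sentence-pair input can be sketched in plain Python. This is a simplified illustration, assuming each sentence has already been converted into word-piece ids and reusing the [CLS]/[SEP] ids 101/102 visible in the output above; `build_pair_input` is a hypothetical helper, not the tokenizer's actual implementation:

```python
CLS_ID, SEP_ID = 101, 102  # [CLS] and [SEP] ids, as observed in encoded_input above


def build_pair_input(ids_a, ids_b):
    """Assemble [CLS] A [SEP] B [SEP] together with segment ids and an attention mask."""
    input_ids = [CLS_ID] + ids_a + [SEP_ID] + ids_b + [SEP_ID]
    # segment 0 covers [CLS], sentence A and the first [SEP]; segment 1 covers the rest
    token_type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
    # no padding yet, so every position is attended to
    attention_mask = [1] * len(input_ids)
    return {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}


ids_a = [2054, 2003, 1996, 3357, 2011, 3357, 5009, 2000, 15697, 1999, 3745, 3006, 1999, 2634, 1029]
ids_b = [2054, 2003, 1996, 3357, 2011, 3357, 5009, 2000, 15697, 1999, 3745, 3006, 1029]
pair = build_pair_input(ids_a, ids_b)
print(len(pair["input_ids"]), pair["token_type_ids"].count(0), pair["token_type_ids"].count(1))  # 31 17 14
```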

The preprocessing step is task specific; if we use a different dataset, this function needs to be modified accordingly.

In [9]:
def tokenize_fn(examples):
    labels = [int(label) for label in examples['is_duplicate']]
    texts = [question['text'] for question in examples['questions']]
    texts1 = [text[0] for text in texts]
    texts2 = [text[1] for text in texts]
    tokenized_examples = tokenizer(texts1, texts2)
    tokenized_examples['labels'] = labels
    return tokenized_examples
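With `batched=True`, `map` passes `tokenize_fn` a dictionary of columns, where each value is a list with one entry per example. The sketch below isolates the unpacking logic on a hand-built batch shaped like the record shown earlier; `split_question_pairs` is a hypothetical helper that mirrors the first lines of `tokenize_fn` without calling the tokenizer, and the second record is made up for illustration:

```python
def split_question_pairs(examples):
    """Turn a batch of Quora records into (first questions, second questions, int labels)."""
    labels = [int(label) for label in examples["is_duplicate"]]
    texts = [question["text"] for question in examples["questions"]]
    texts1 = [text[0] for text in texts]
    texts2 = [text[1] for text in texts]
    return texts1, texts2, labels


# a batch of two examples in the column-oriented layout that batched map uses
batch = {
    "questions": [
        {"id": [1, 2], "text": ["What is the step by step guide to invest in share market in india?",
                                "What is the step by step guide to invest in share market?"]},
        {"id": [3, 4], "text": ["How do I learn Python?", "What is the best way to learn Python?"]},
    ],
    "is_duplicate": [False, True],
}
texts1, texts2, labels = split_question_pairs(batch)
print(labels)  # [0, 1]
```

Passing `texts1` and `texts2` as two separate lists is what lets the tokenizer build one [CLS] A [SEP] B [SEP] sequence per pair.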
In [10]:
dataset_dict_tokenized = dataset_dict.map(
    tokenize_fn,
    batched=True,
    num_proc=8,
    remove_columns=['is_duplicate', 'questions']
)
dataset_dict_tokenized
                
Out[10]:
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 327474
    })
    val: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 36387
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 40429
    })
})
In [11]:
dataset_dict_tokenized['train'][0]
Out[11]:
{'input_ids': [101,
  2052,
  1045,
  2022,
  2583,
  2000,
  10887,
  2048,
  2797,
  14766,
  1998,
  2059,
  2131,
  2068,
  2000,
  3582,
  2169,
  2060,
  1029,
  102,
  2054,
  2097,
  4148,
  2065,
  1045,
  10887,
  2048,
  2797,
  18145,
  2000,
  3582,
  2169,
  2060,
  1029,
  102],
 'token_type_ids': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1],
 'labels': 1}

Model Fine-Tuning

Having preprocessed our raw dataset, we use the AutoModelForSequenceClassification class to load the pre-trained model for our text classification task. The only other argument we need to specify is the number of classes/labels in our task. When instantiating this model for the first time, we'll see some warnings telling us that the classification head is newly initialized and that we should fine-tune the model on a downstream task before using it.

In [12]:
model_checkpoint = 'text_classification'
num_labels = 2
In [13]:
# we'll save the model after fine tuning it once, so we can skip the fine tuning part during
# the second round if we detect that we already have one available
if os.path.isdir(model_checkpoint):
    model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint).to(device)
else:
    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path,
        num_labels=num_labels
    ).to(device)

print('# of parameters: ', model.num_parameters())
Some weights of the model checkpoint at google/mobilebert-uncased were not used when initializing MobileBertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.dense.weight']
- This IS expected if you are initializing MobileBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MobileBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of MobileBertForSequenceClassification were not initialized from the model checkpoint at google/mobilebert-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
# of parameters:  24582914
In [14]:
data_collator = DataCollatorWithPadding(tokenizer, padding=True)
data_collator
Out[14]:
DataCollatorWithPadding(tokenizer=PreTrainedTokenizerFast(name_or_path='google/mobilebert-uncased', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}), padding=True, max_length=None, pad_to_multiple_of=None, return_tensors='pt')
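`DataCollatorWithPadding` pads each batch only up to the length of its longest sequence, rather than to a fixed maximum length, which saves compute on short question pairs. A simplified sketch of that idea, assuming a pad id of 0 (matching `pad_token_id` in the model config) and right-side padding (matching `padding_side='right'` above); `pad_batch` is a hypothetical helper, and the real collator also handles the `labels` field and returns PyTorch tensors:

```python
def pad_batch(features, pad_id=0):
    """Right-pad input_ids/token_type_ids/attention_mask to the longest sequence in the batch."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "token_type_ids": [], "attention_mask": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_id] * n_pad)
        batch["token_type_ids"].append(f["token_type_ids"] + [0] * n_pad)
        # padded positions get attention_mask 0 so the model ignores them
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)
    return batch


features = [
    {"input_ids": [101, 2054, 102, 2054, 102], "token_type_ids": [0, 0, 0, 1, 1], "attention_mask": [1] * 5},
    {"input_ids": [101, 102, 102], "token_type_ids": [0, 0, 1], "attention_mask": [1] * 3},
]
padded = pad_batch(features)
print(padded["attention_mask"])  # [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]
```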

We could perform all sorts of hyperparameter tuning on the fine-tuning step; here we simply pick some default parameters for illustration purposes.

In [15]:
batch_size = 64
args = TrainingArguments(
    "quora_fine_tuned",
    learning_rate=1e-4,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=2,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True
)

trainer = Trainer(
    model,
    args,
    data_collator=data_collator,
    train_dataset=dataset_dict_tokenized["train"],
    eval_dataset=dataset_dict_tokenized['val']
)
In [16]:
if not os.path.isdir(model_checkpoint):
    trainer.train()
    model.save_pretrained(model_checkpoint)
/home/mingyuliu/.local/lib/python3.7/site-packages/transformers/optimization.py:310: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  FutureWarning,
***** Running training *****
  Num examples = 327474
  Num Epochs = 2
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 10234
[10234/10234 42:58, Epoch 2/2]
Epoch  Training Loss  Validation Loss
1      0.267200       0.254776
2      0.201200       0.237408

***** Running Evaluation *****
  Num examples = 36387
  Batch size = 64
Saving model checkpoint to quora_fine_tuned/checkpoint-5117
Configuration saved in quora_fine_tuned/checkpoint-5117/config.json
Model weights saved in quora_fine_tuned/checkpoint-5117/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 36387
  Batch size = 64
Saving model checkpoint to quora_fine_tuned/checkpoint-10234
Configuration saved in quora_fine_tuned/checkpoint-10234/config.json
Model weights saved in quora_fine_tuned/checkpoint-10234/pytorch_model.bin


Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from quora_fine_tuned/checkpoint-10234 (score: 0.23740820586681366).
Configuration saved in text_classification/config.json
Model weights saved in text_classification/pytorch_model.bin
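The step counts in the training log follow from the dataset size and batch size. With a single device and no gradient accumulation (as the log reports), each epoch takes ceil(327474 / 64) steps, which also explains the checkpoint names checkpoint-5117 and checkpoint-10234 given `save_strategy="epoch"`. A quick check of that arithmetic:

```python
import math

num_examples = 327474
batch_size = 64
num_epochs = 2

steps_per_epoch = math.ceil(num_examples / batch_size)  # the last, partial batch still counts as a step
total_steps = steps_per_epoch * num_epochs
checkpoint_steps = [steps_per_epoch * (epoch + 1) for epoch in range(num_epochs)]
print(steps_per_epoch, total_steps, checkpoint_steps)  # 5117 10234 [5117, 10234]
```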
In [17]:
class SoftmaxModule(nn.Module):
    """
    Add a softmax layer on top of the base model. Note this does not necessarily
    mean the output score is a well-calibrated probability.
    """

    def __init__(self, model_path: str):
        super().__init__()
        self.model_path = model_path
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        scores = nn.functional.softmax(outputs.logits, dim=-1)[:, 1]
        return scores
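For a two-class head, taking column 1 of the softmax is equivalent to applying a sigmoid to the difference between the two logits, so the score is a monotone function of how much larger the "duplicate" logit is than the "not duplicate" one. A small NumPy check of that identity, using made-up logits:

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax with the usual max-subtraction for numerical stability."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


logits = np.array([[2.0, -1.5], [0.3, 0.1], [-4.0, 3.0]])  # illustrative logits, shape (batch, 2)
scores_softmax = softmax(logits)[:, 1]
scores_sigmoid = sigmoid(logits[:, 1] - logits[:, 0])
print(np.allclose(scores_softmax, scores_sigmoid))  # True
```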
In [18]:
softmax_module = SoftmaxModule(model_checkpoint).to(device)
softmax_module.eval()
print('# of parameters: ', sum(p.numel() for p in softmax_module.parameters() if p.requires_grad))
loading configuration file text_classification/config.json
Model config MobileBertConfig {
  "_name_or_path": "text_classification",
  "architectures": [
    "MobileBertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_activation": false,
  "classifier_dropout": null,
  "embedding_size": 128,
  "hidden_act": "relu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 512,
  "intra_bottleneck_size": 128,
  "key_query_shared_bottleneck": true,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "mobilebert",
  "normalization_type": "no_norm",
  "num_attention_heads": 4,
  "num_feedforward_networks": 4,
  "num_hidden_layers": 24,
  "pad_token_id": 0,
  "problem_type": "single_label_classification",
  "torch_dtype": "float32",
  "transformers_version": "4.20.1",
  "trigram_input": true,
  "true_hidden_size": 128,
  "type_vocab_size": 2,
  "use_bottleneck": true,
  "use_bottleneck_attention": false,
  "vocab_size": 30522
}

loading weights file text_classification/pytorch_model.bin
All model checkpoint weights were used when initializing MobileBertForSequenceClassification.

All the weights of MobileBertForSequenceClassification were initialized from the model checkpoint at text_classification.
If your task is similar to the task the model of the checkpoint was trained on, you can already use MobileBertForSequenceClassification for predictions without further training.
# of parameters:  24582914

We define some helper functions that generate predictions for our dataset and store the predicted scores and labels in a pandas DataFrame.

In [19]:
def predict(model, examples, round_digits: int = 5):
    input_ids = examples['input_ids'].to(device)
    attention_mask = examples['attention_mask'].to(device)
    token_type_ids = examples['token_type_ids'].to(device)
    batch_labels = examples['labels'].detach().cpu().numpy().tolist()
    model.eval()
    with torch.no_grad():
        batch_output = model(input_ids, attention_mask, token_type_ids)

    batch_scores = np.round(batch_output.detach().cpu().numpy(), round_digits).tolist()
    return batch_scores, batch_labels
In [20]:
def predict_data_loader(model, data_loader: DataLoader) -> pd.DataFrame:
    scores = []
    labels = []
    for examples in data_loader:
        batch_scores, batch_labels = predict(model, examples)
        scores += batch_scores
        labels += batch_labels

    df_predictions = pd.DataFrame.from_dict({'scores': scores, 'labels': labels})
    return df_predictions
In [21]:
data_collator = DataCollatorWithPadding(tokenizer, padding=True)
data_loader = DataLoader(dataset_dict_tokenized['test'], collate_fn=data_collator, batch_size=128)
start = time.time()
df_predictions = predict_data_loader(softmax_module, data_loader)
end = time.time()

print('elapsed: ', end - start)
print(df_predictions.shape)
df_predictions.head()
elapsed:  30.936853170394897
(40429, 2)
Out[21]:
scores labels
0 0.00001 0
1 0.00000 0
2 0.04782 0
3 0.00012 0
4 0.00442 0

Model Calibration

Temperature Scaling

Temperature scaling is a post-processing technique proposed to reduce calibration error, designed specifically for deep learning models. It works by dividing the logits (the output of the layer right before the final softmax) by a single learned scalar parameter.

\begin{align} \text{softmax}(\mathbf{z}/T)_j = \frac{e^{z_j/T}}{\sum_i e^{z_i/T}} \end{align}

where $\mathbf{z}$ is the vector of logits, $z_j$ is the logit for class $j$, and $T$ is the learned temperature scaling parameter. We learn this parameter on a validation set, choosing $T$ to minimize the negative log likelihood. With $T > 1$, the softmax output is softened: the largest predicted probability decreases and the smaller ones increase, making the model less confident about its predictions ($T < 1$ does the opposite). Because every logit is divided by the same positive constant, their ordering is preserved, so the model's predicted class does not change.
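A small numeric sketch of the formula above, using made-up logits for a three-class example: dividing by a temperature greater than 1 lowers the top probability and raises the others, while the ranking of the classes, and hence the predicted class, stays the same.

```python
import numpy as np


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature, computed stably."""
    scaled = np.asarray(logits, dtype=float) / temperature
    scaled = scaled - scaled.max()
    exp = np.exp(scaled)
    return exp / exp.sum()


logits = [3.0, 1.0, 0.2]  # illustrative logits z_i
for t in [1.0, 2.0]:
    probs = softmax_with_temperature(logits, t)
    print(t, np.round(probs, 3), probs.argmax())
```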

The benefits of this approach are mainly twofold:

  • Unlike many post-processing calibration techniques, temperature scaling can be embedded directly into our deep learning module as a single additional parameter. We can export the model as is, using the standard serialization method of the deep learning library, and run inference without introducing additional dependencies.
  • The original paper [2] showed that it provides strong calibration performance compared to other post-processing calibration techniques.
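Because the temperature only rescales the logits, it can even be absorbed into the model's last linear layer (for MobileBertForSequenceClassification that is the `classifier` layer named in the loading warning): dividing its weight and bias by $T$ produces exactly the logits divided by $T$. This is one deployment option, not what the notebook does (it keeps $T$ as a separate `nn.Parameter`); the sketch below uses random NumPy stand-ins for the pooled features and classifier weights:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden = rng.normal(size=(4, 8))   # stand-in for pooled features, shape (batch, hidden_dim)
weight = rng.normal(size=(2, 8))   # stand-in for classifier.weight, shape (num_labels, hidden_dim)
bias = rng.normal(size=2)          # stand-in for classifier.bias
temperature = 1.2

# option 1: keep T separate and divide the logits at inference time
logits_scaled = (hidden @ weight.T + bias) / temperature

# option 2: fold T into the last linear layer once, then export it as a plain classifier
folded_weight = weight / temperature
folded_bias = bias / temperature
logits_folded = hidden @ folded_weight.T + folded_bias

print(np.allclose(logits_scaled, logits_folded))  # True
```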
In [22]:
class TemperatureScalingCalibrationModule(nn.Module):

    def __init__(self, model_path: str):
        super().__init__()
        self.model_path = model_path
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)

        # the single temperature scaling parameter, the initialization value doesn't
        # seem to matter that much based on some ad-hoc experimentation
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(self, input_ids, attention_mask, token_type_ids):
        """forward method that returns softmax-ed confidence scores."""
        outputs = self.forward_logit(input_ids, attention_mask, token_type_ids)
        scores = nn.functional.softmax(outputs, dim=-1)[:, 1]
        return scores

    def forward_logit(self, input_ids, attention_mask, token_type_ids):
        """forward method that returns logits, to be used with cross entropy loss."""
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        ).logits
        return outputs / self.temperature

    def fit(self, dataset_tokenized, n_epochs: int = 3, batch_size: int = 64, lr: float = 0.01):
        """fits the temperature scaling parameter."""
        data_collator = DataCollatorWithPadding(tokenizer, padding=True)
        data_loader = DataLoader(dataset_tokenized, collate_fn=data_collator, batch_size=batch_size)

        self.freeze_base_model()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(self.parameters(), lr=lr)

        for epoch in trange(n_epochs):
            for examples in data_loader:
                labels = examples['labels'].long().to(device)
                input_ids = examples['input_ids'].to(device)
                attention_mask = examples['attention_mask'].to(device)
                token_type_ids = examples['token_type_ids'].to(device)

                # standard step to perform the forward and backward step
                self.zero_grad()
                # temperature-scaled logits; CrossEntropyLoss applies log-softmax internally
                logits = self.forward_logit(input_ids, attention_mask, token_type_ids)
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()

        return self

    def freeze_base_model(self):
        """remember to freeze base model's parameters when training temperature scaler"""
        self.model.eval()
        for parameter in self.model.parameters():
            parameter.requires_grad = False

        return self
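The `fit` method above re-runs the frozen base model over the whole validation set in every epoch, even though only $T$ changes. Since the base model is frozen, an alternative is to compute the validation logits once and minimize the negative log likelihood over $T$ directly. The sketch below does this with a simple log-spaced grid search in NumPy on synthetic, deliberately overconfident logits; it is a swapped-in optimizer for illustration, not the notebook's SGD procedure, and `fit_temperature` / `nll_at_temperature` are hypothetical helpers:

```python
import numpy as np


def nll_at_temperature(logits, labels, temperature):
    """Mean negative log likelihood of labels under softmax(logits / temperature)."""
    scaled = logits / temperature
    scaled = scaled - scaled.max(axis=1, keepdims=True)
    log_probs = scaled - np.log(np.exp(scaled).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def fit_temperature(logits, labels, grid=None):
    """Pick the temperature on a log-spaced grid that minimizes validation NLL."""
    if grid is None:
        grid = np.exp(np.linspace(np.log(0.25), np.log(4.0), 401))
    losses = [nll_at_temperature(logits, labels, t) for t in grid]
    return float(grid[int(np.argmin(losses))])


# labels are drawn from softmax(true_logits), but the "model" reports logits 2x larger
rng = np.random.default_rng(0)
true_logits = rng.normal(scale=1.5, size=(20_000, 2))
probs = np.exp(true_logits) / np.exp(true_logits).sum(axis=1, keepdims=True)
labels = (rng.uniform(size=len(probs)) < probs[:, 1]).astype(int)
overconfident_logits = 2.0 * true_logits

print(fit_temperature(overconfident_logits, labels))  # should land near 2.0
```

With a single scalar parameter the NLL curve is one-dimensional, so a grid search (or any 1-D optimizer) over cached logits is cheap compared with repeated forward passes through the base model.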
In [23]:
calibration_module = TemperatureScalingCalibrationModule(model_checkpoint).to(device)
calibration_module.fit(dataset_dict_tokenized['val'])
print('# of parameters: ', sum(p.numel() for p in calibration_module.parameters() if p.requires_grad))
loading configuration file text_classification/config.json
Model config MobileBertConfig {
  "_name_or_path": "text_classification",
  "architectures": [
    "MobileBertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_activation": false,
  "classifier_dropout": null,
  "embedding_size": 128,
  "hidden_act": "relu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 512,
  "intra_bottleneck_size": 128,
  "key_query_shared_bottleneck": true,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "mobilebert",
  "normalization_type": "no_norm",
  "num_attention_heads": 4,
  "num_feedforward_networks": 4,
  "num_hidden_layers": 24,
  "pad_token_id": 0,
  "problem_type": "single_label_classification",
  "torch_dtype": "float32",
  "transformers_version": "4.20.1",
  "trigram_input": true,
  "true_hidden_size": 128,
  "type_vocab_size": 2,
  "use_bottleneck": true,
  "use_bottleneck_attention": false,
  "vocab_size": 30522
}

loading weights file text_classification/pytorch_model.bin
All model checkpoint weights were used when initializing MobileBertForSequenceClassification.

All the weights of MobileBertForSequenceClassification were initialized from the model checkpoint at text_classification.
If your task is similar to the task the model of the checkpoint was trained on, you can already use MobileBertForSequenceClassification for predictions without further training.
100%|██████████| 3/3 [01:11<00:00, 24.00s/it]
# of parameters:  1

In [24]:
calibration_module.temperature
Out[24]:
Parameter containing:
tensor([1.1972], device='cuda:0', requires_grad=True)
In [25]:
data_collator = DataCollatorWithPadding(tokenizer, padding=True)
data_loader = DataLoader(dataset_dict_tokenized['test'], collate_fn=data_collator, batch_size=128)
start = time.time()
df_calibrated_predictions = predict_data_loader(calibration_module, data_loader)
end = time.time()

print('elapsed: ', end - start)
print(df_calibrated_predictions.shape)
df_calibrated_predictions.head()
elapsed:  30.429730653762817
(40429, 2)
Out[25]:
scores labels
0 0.00006 0
1 0.00000 0
2 0.07596 0
3 0.00051 0
4 0.01071 0
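Because the model has two classes, a calibrated score can be reproduced from the uncalibrated one without rerunning the model: convert the score back to the logit difference (its log-odds), divide by the learned temperature, and apply a sigmoid. A quick check against the first rows of the two prediction tables, using T = 1.1972 from the output above; the match is only approximate because the displayed scores and temperature are rounded, and `rescale_score` is a hypothetical helper:

```python
import math


def rescale_score(score, temperature):
    """Apply temperature scaling to a binary softmax score via its log-odds."""
    log_odds = math.log(score / (1.0 - score))  # equals logit_1 - logit_0
    return 1.0 / (1.0 + math.exp(-log_odds / temperature))


temperature = 1.1972
for uncalibrated, calibrated in [(0.04782, 0.07596), (0.00442, 0.01071)]:
    print(uncalibrated, round(rescale_score(uncalibrated, temperature), 5), calibrated)
```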

Observations:

  • Based on the calibration plot below, the predicted scores on this particular dataset are concentrated at the higher end. The original predicted scores already seem fairly well calibrated, and temperature scaling improves the calibration metrics further: log loss drops from 0.2430 to 0.2386, Brier score from 0.0749 to 0.0742, and calibration error from 0.0415 to 0.0313, while AUC, precision, recall and F1 are unchanged because the predicted class is unchanged.
  • The learned temperature (1.1972) is larger than 1, which indicates that it softens the predicted scores, pulling them toward 0.5 and making the model less confident in its predictions (e.g., a score of 0.04782 becomes 0.07596).
In [26]:
from calibration_module.utils import compute_calibration_summary


eval_dict = {
    f'{model_checkpoint}': df_predictions,
    f'{model_checkpoint}_calibrated': df_calibrated_predictions
}

# change default style figure and font size
plt.rcParams['figure.figsize'] = 12, 12
plt.rcParams['font.size'] = 12

n_bins = 20
df_result = compute_calibration_summary(eval_dict, label_col='labels', score_col='scores', n_bins=n_bins)
df_result
Out[26]:
   auc     precision  recall  f1      log_loss  brier   calibration_error  name
0  0.9627  0.8283     0.9091  0.8668  0.2430    0.0749  0.0415             text_classification
1  0.9627  0.8283     0.9091  0.8668  0.2386    0.0742  0.0313             text_classification_calibrated
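The calibration_error column comes from the `compute_calibration_summary` helper, whose exact definition is not shown here. A common metric of this kind is the expected calibration error (ECE): the bin-size-weighted average absolute gap between the mean predicted score and the observed positive rate. A minimal sketch, assuming equal-width bins; it is not guaranteed to match the helper's implementation, and `expected_calibration_error` is a hypothetical name:

```python
import numpy as np


def expected_calibration_error(scores, labels, n_bins=20):
    """Weighted mean |mean score - positive rate| over equal-width score bins."""
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=float)
    bin_ids = np.minimum((scores * n_bins).astype(int), n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        mask = bin_ids == b
        if mask.any():
            gap = abs(scores[mask].mean() - labels[mask].mean())
            ece += mask.mean() * gap  # mask.mean() is the fraction of examples in this bin
    return ece


# usage with the prediction DataFrames from this notebook (not executed here):
# expected_calibration_error(df_predictions['scores'], df_predictions['labels'], n_bins=20)
```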

Other works [3] study the calibration of state-of-the-art models. Although they focus mainly on image models, they claim that model size and amount of pretraining do not fully account for differences in calibration across models; the primary factor appears to be model architecture. More explicitly, models that rely on attention-based mechanisms are found to be better calibrated than convolution-based ones.

Reference

  • [1] Blog: Temperature Scaling for Neural Network Calibration
  • [2] Chuan Guo, Geoff Pleiss, Yu Sun, Kilian Q. Weinberger - On Calibration of Modern Neural Networks (2017)
  • [3] Matthias Minderer, Josip Djolonga, Rob Romijnders, Frances Hubis, Xiaohua Zhai, Neil Houlsby, Dustin Tran, Mario Lucic - Revisiting the Calibration of Modern Neural Networks (2021)