In [1]:
# code for loading the format for the notebook
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))

from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
Out[1]:
In [1]:
os.chdir(path)

# 1. magic for inline plot
# 2. magic to print version
# 3. magic so that the notebook will reload external python modules
# 4. magic to enable retina (high resolution) plots
# https://gist.github.com/minrk/3301035
%matplotlib inline
%load_ext watermark
%load_ext autoreload
%autoreload 2

import os
import torch
import scipy
import datasets
import transformers
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from time import perf_counter
from torch.utils.data import DataLoader
from datasets import (
    load_dataset,
    concatenate_datasets,
    disable_progress_bar,
    DatasetDict
)
from datasets.utils.logging import set_verbosity_error
from transformers import (
    Trainer,
    TrainingArguments,
    AutoModel,
    AutoTokenizer,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
    IntervalStrategy
)
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics.pairwise import paired_cosine_distances

device = "cuda" if torch.cuda.is_available() else "cpu"
cache_dir = None

%watermark -a 'Ethen' -d -u -iv
/usr/local/lib/python3.8/dist-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Author: Ethen

Last updated: 2023-04-15

datasets    : 2.11.0
pandas      : 1.4.3
scipy       : 1.9.0
numpy       : 1.23.2
torch       : 2.0.0
transformers: 4.28.1

Sentence Transformer: Training Bi-Encoder via Contrastive Loss

In modern recommendation systems, it is common to have two distinct major stages: a recall/retrieval stage and a ranking stage [2] [3]. In this article, we'll discuss these two stages in the context of pre-trained encoder models, where different architectures are commonly used at different stages to play to their respective strengths.

Recall/Retrieval focuses on candidate generation: given millions of items, be it inventories, passages, etc., stored in our database, we need a way to efficiently retrieve the subset of them that is most relevant to the given context. In the encoder model paradigm, this means representing our incoming context, be it a query, question, etc., as a vector, often referred to as a sentence embedding, and performing a similarity search between it and entities that are already represented as sentence embeddings. This step is often done using the bi-encoder architecture, where we pass individual sentences/entities through our encoder and then perform similarity search using metrics such as cosine similarity. Being able to feed individual entities through our encoder is the key to an efficient retrieval stage, where we need to quickly scan through our database, as it means we can often pre-compute these embeddings a priori.
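
As a minimal sketch of this pre-compute-then-search pattern (assuming we already have some encode function that maps sentences to a 2D embedding array; the names here are purely illustrative):

import numpy as np

def retrieve_top_k(query_embedding, candidate_embeddings, k=5):
    # normalize both sides so that a dot product equals cosine similarity
    query_norm = query_embedding / np.linalg.norm(query_embedding)
    candidate_norm = candidate_embeddings / np.linalg.norm(candidate_embeddings, axis=1, keepdims=True)
    scores = candidate_norm @ query_norm
    # indices of the k most similar candidates, highest score first
    return np.argsort(-scores)[:k]

# candidate embeddings can be pre-computed offline and stored,
# only the incoming query needs to be encoded at serving time:
# candidate_embeddings = encode(candidate_sentences)
# top_k = retrieve_top_k(encode([query])[0], candidate_embeddings)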

This subset is then passed to the ranker: now that we have our (context, candidate) pairs, we need to score them so the system can rank the retrieved inventories in sorted order. This step is often done using a cross encoder, a.k.a. cross attention architecture. A cross encoder doesn't produce sentence embeddings for our data but instead relies on a classification mechanism over the input pairs. Hence, whenever we have a fixed set of data pairs we wish to score, cross encoders often achieve better performance than their bi-encoder counterparts, even when using less training data.
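
For illustration, a minimal cross encoder sketch tokenizes the pair jointly and scores it with a sequence classification head; the checkpoint below is only an example of a publicly available cross encoder, not something used elsewhere in this document:

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# example cross encoder checkpoint fine-tuned for relevance scoring
ce_checkpoint = "cross-encoder/ms-marco-MiniLM-L-6-v2"
ce_tokenizer = AutoTokenizer.from_pretrained(ce_checkpoint)
ce_model = AutoModelForSequenceClassification.from_pretrained(ce_checkpoint)

# both sentences go through the model together, allowing full cross attention
pair = ce_tokenizer("some query", "some retrieved passage", truncation=True, return_tensors="pt")
with torch.no_grad():
    relevance_score = ce_model(**pair).logits  # higher logit means more relevant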

In this document, our focus will be on how to train a bi-encoder style architecture using contrastive loss [1] [4].

Contrastive Loss

For a bi-encoder, our input data at the bare minimum needs to contain positive pairs, where we need to be creative and define what counts as a positive pair based on our use case. These pairs are often referred to as anchors $a$ and positives $p$. Given such a dataset, the standard architecture for this type of task is a siamese network. Each base arm consists of a pre-trained encoder followed by a pooling operator that derives a fixed-size sentence embedding vector. During each step, we feed our anchor and positive, one at a time, through the same encoder plus pooling arm, whose weights are tied across the pair. Tying weights ensures similar entities are mapped to nearby locations in the representation space.

When training this network, what we wish to accomplish is to have our anchor and positive pairs $a_i$ and $p_i$ become closer in the vector space, whereas $a_i$ and some random negative examples $p_j$ become more distant.

\begin{align} L = -\frac{1}{n} \sum^n_{i=1} \log \frac{\exp\big(sim(a_i, p_i)\big)}{\sum_j \exp\big(sim(a_i, p_j)\big)} \end{align}

Where $sim$ is a similarity function such as cosine similarity or dot product. This can then be treated as a classification task using cross entropy loss. $p_j$ is often sampled using in-batch negatives, meaning for each example in a batch, the other examples' positives are taken as its negatives. This effectively re-uses the data already sampled in our current batch without the need to implement additional negative sampling logic.
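
To make the in-batch negatives idea concrete, a minimal sketch of this loss (which mirrors the SiameseModel implemented later in this document) looks like:

import torch
import torch.nn.functional as F

def in_batch_negatives_loss(anchor_embedding, positive_embedding, scale=20.0):
    # [batch size, batch size] similarity matrix, entry (i, j) = sim(a_i, p_j);
    # the diagonal holds the true (anchor, positive) pairs
    anchor_norm = F.normalize(anchor_embedding, p=2, dim=1)
    positive_norm = F.normalize(positive_embedding, p=2, dim=1)
    scores = anchor_norm @ positive_norm.T * scale
    # row i is a classification problem whose correct class is column i
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)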

Dataset

For our dataset, we'll be using Stanford Natural Language Inference (SNLI) and Multi-Genre NLI (MNLI); refer to their links for more details on these two datasets. For contrastive loss, we will keep only premise and hypothesis pairs that have a label of 0 assigned to them. These are positive samples, where the premise entails the hypothesis.

We'll also use the Semantic Textual Similarity Benchmark (stsb) dataset, which contains sentence pairs with a human annotated similarity score from 0 to 5, for evaluating our model's performance.

In [2]:
# prevents progress bar and logging from flooding our document
disable_progress_bar()
set_verbosity_error()

stsb = load_dataset('glue', 'stsb', cache_dir=cache_dir)

# we normalize the score to a range of 0 ~ 1
stsb = stsb.map(lambda x: {'label': x['label'] / 5.0})
stsb = stsb.rename_column("sentence1", "anchor_text")
stsb = stsb.rename_column("sentence2", "positive_text")
stsb = stsb.remove_columns(["idx"])
stsb
Out[2]:
DatasetDict({
    train: Dataset({
        features: ['anchor_text', 'positive_text', 'label'],
        num_rows: 5749
    })
    validation: Dataset({
        features: ['anchor_text', 'positive_text', 'label'],
        num_rows: 1500
    })
    test: Dataset({
        features: ['anchor_text', 'positive_text', 'label'],
        num_rows: 1379
    })
})
In [3]:
stsb["validation"][1]
Out[3]:
{'anchor_text': 'A young child is riding a horse.',
 'positive_text': 'A child is riding a horse.',
 'label': 0.949999988079071}
In [4]:
snli = load_dataset('snli', cache_dir=cache_dir)
mnli = load_dataset('glue', 'mnli', cache_dir=cache_dir)
mnli = mnli.remove_columns(['idx'])
snli = snli.cast(mnli["train"].features)

dataset_train = concatenate_datasets([snli["train"], mnli["train"]])
dataset_train = dataset_train.filter(lambda x: x['label'] == 0)
dataset_train = dataset_train.rename_column("premise", "anchor_text")
dataset_train = dataset_train.rename_column("hypothesis", "positive_text")

# note that we are evaluating the performance of our sentence embedding on
# stsb dataset without using any stsb samples in our training dataset
dataset_dict = DatasetDict({
    "train": dataset_train,
    "validation": stsb["train"],
    "test": stsb["validation"]
})
dataset_dict
Out[4]:
DatasetDict({
    train: Dataset({
        features: ['anchor_text', 'positive_text', 'label'],
        num_rows: 314315
    })
    validation: Dataset({
        features: ['anchor_text', 'positive_text', 'label'],
        num_rows: 5749
    })
    test: Dataset({
        features: ['anchor_text', 'positive_text', 'label'],
        num_rows: 1500
    })
})
In [5]:
dataset_dict["train"][0]
Out[5]:
{'anchor_text': 'A person on a horse jumps over a broken down airplane.',
 'positive_text': 'A person is outdoors, on a horse.',
 'label': 0}
In [6]:
# we can experiment with different pre-trained model checkpoint, e.g. "roberta-base"
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).

As usual, we tokenize our anchor and positive text.

In [7]:
def tokenize_fn(examples):
    tokenized = tokenizer(examples["anchor_text"], truncation=True, max_length=128)
    examples["anchor_input_ids"] = tokenized["input_ids"]
    examples["anchor_attention_mask"] = tokenized["attention_mask"]

    tokenized = tokenizer(examples["positive_text"], truncation=True, max_length=128)
    examples["positive_input_ids"] = tokenized["input_ids"]
    examples["positive_attention_mask"] = tokenized["attention_mask"]
    return examples
In [8]:
dataset_dict_tokenized = dataset_dict.map(
    tokenize_fn,
    batched=True,
    num_proc=8,
    remove_columns=["anchor_text", "positive_text"]
)
dataset_dict_tokenized
Out[8]:
DatasetDict({
    train: Dataset({
        features: ['label', 'anchor_input_ids', 'anchor_attention_mask', 'positive_input_ids', 'positive_attention_mask'],
        num_rows: 314315
    })
    validation: Dataset({
        features: ['label', 'anchor_input_ids', 'anchor_attention_mask', 'positive_input_ids', 'positive_attention_mask'],
        num_rows: 5749
    })
    test: Dataset({
        features: ['label', 'anchor_input_ids', 'anchor_attention_mask', 'positive_input_ids', 'positive_attention_mask'],
        num_rows: 1500
    })
})
In [9]:
dataset_dict_tokenized["train"][0]
Out[9]:
{'label': 0,
 'anchor_input_ids': [101,
  1037,
  2711,
  2006,
  1037,
  3586,
  14523,
  2058,
  1037,
  3714,
  2091,
  13297,
  1012,
  102],
 'anchor_attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'positive_input_ids': [101,
  1037,
  2711,
  2003,
  19350,
  1010,
  2006,
  1037,
  3586,
  1012,
  102],
 'positive_attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

Here we also define a custom collate function for batching our dataset.

In [10]:
from dataclasses import dataclass
from typing import List, Dict, Optional, Union
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.file_utils import PaddingStrategy


@dataclass
class DataCollatorForSiamese:
    """
    Huggingface's DataCollatorWithPadding is for a single input, here we extend it
    to take paired inputs.
    
    References
    ----------
    https://huggingface.co/docs/transformers/main_classes/data_collator#transformers.DataCollatorWithPadding
    """
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        anchor_features = {
            "input_ids": [feature["anchor_input_ids"] for feature in features],
            "attention_mask": [feature["anchor_attention_mask"] for feature in features],
            "label": [feature["label"] for feature in features]
        }
        pos_features = {
            "input_ids": [feature["positive_input_ids"] for feature in features],
            "attention_mask": [feature["positive_attention_mask"] for feature in features]
        }
        anchor_batch = self.tokenizer.pad(
            anchor_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        pos_batch = self.tokenizer.pad(
            pos_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )

        batch = {
            "anchor_input_ids": anchor_batch["input_ids"],
            "anchor_attention_mask": anchor_batch["attention_mask"],
            "positive_input_ids": pos_batch["input_ids"],
            "positive_attention_mask": pos_batch["attention_mask"],
            "labels": anchor_batch["label"]
        }
        return batch
In [11]:
# sample output from a data loader
batch_size = 2
data_collator = DataCollatorForSiamese(tokenizer)
data_loader = DataLoader(dataset_dict_tokenized["train"], batch_size=batch_size, collate_fn=data_collator)
batch = next(iter(data_loader))
batch
You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Out[11]:
{'anchor_input_ids': tensor([[  101,  1037,  2711,  2006,  1037,  3586, 14523,  2058,  1037,  3714,
           2091, 13297,  1012,   102],
         [  101,  2336,  5629,  1998, 12015,  2012,  4950,   102,     0,     0,
              0,     0,     0,     0]]),
 'anchor_attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]]),
 'positive_input_ids': tensor([[  101,  1037,  2711,  2003, 19350,  1010,  2006,  1037,  3586,  1012,
            102],
         [  101,  2045,  2024,  2336,  2556,   102,     0,     0,     0,     0,
              0]]),
 'positive_attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]]),
 'labels': tensor([0, 0])}

Model

This section defines our transformer-plus-pooling encoder, siamese network, and evaluation metric, then proceeds to use the transformers library's Trainer class for training our network.

In [12]:
class TransformerPoolingEncoder(nn.Module):

    def __init__(self, model_name: str, pooling_mode: str = "avg", cache_dir = None):
        super().__init__()
        if pooling_mode not in {"avg", "cls"}:
            raise ValueError(f"{pooling_mode} needs to be one of: avg, cls")

        self.base_model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
        self.pooling_mode = pooling_mode

    def forward(self, input_ids, attention_mask):
        output = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        if self.pooling_mode == "avg":
            # reshape attention mask to cover all hidden dimension
            mask = attention_mask.unsqueeze(dim=-1).expand_as(output.last_hidden_state)
            # perform mean pooling, but excluding attention mask
            pooled = torch.sum(output.last_hidden_state * mask, dim=1) / torch.clamp(
                mask.sum(dim=1), min=1e-9
            )
        elif self.pooling_mode == "cls":
            pooled = output.last_hidden_state[:, 0, :]

        return pooled
In [13]:
class SiameseModel(nn.Module):
    """
    The loss function we are using is similar to Multiple Negative Ranking Loss that's mentioned
    in the sentence bert package.
    
    The similarity function implemented here is cosine similarity, with cosine similarity, we often
    times add a scaling multiplier. As cosine similarity's range is between -1 and 1, this scaling
    factor ensures our loss are large enough between positive and negative samples for the network
    to train.
    
    References
    ----------
    https://www.sbert.net/docs/package_reference/losses.html#multiplenegativesrankingloss
    """
    
    def __init__(self, model_name: str, scale: float = 20.0, pooling_mode: str = "avg", cache_dir = None):
        super().__init__()
        self.encoder = TransformerPoolingEncoder(model_name, pooling_mode, cache_dir)
        self.scale = scale

    def forward(
        self,
        anchor_input_ids,
        anchor_attention_mask,
        positive_input_ids,
        positive_attention_mask,
        labels=None
    ):
        # [batch size, hidden dim]
        anchor_embedding = self.encoder(
            input_ids=anchor_input_ids,
            attention_mask=anchor_attention_mask
        )
        positive_embedding = self.encoder(
            input_ids=positive_input_ids,
            attention_mask=positive_attention_mask
        )

        # cosine similarity
        anchor_norm = F.normalize(anchor_embedding, p=2, dim=1)
        positive_norm = F.normalize(positive_embedding, p=2, dim=1)
        scores = torch.mm(anchor_norm, positive_norm.transpose(0, 1)) * self.scale

        # Example a[i] should match with p[i]
        labels = torch.arange(scores.size()[0], device=scores.device)
        loss = F.cross_entropy(scores, labels)
        return loss, anchor_embedding, positive_embedding
In [14]:
model = SiameseModel(model_name, cache_dir=cache_dir).to(device)
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
In [15]:
# sample forward output from the model given our input batch
for key, tensor in batch.items():
    batch[key] = tensor.to(device)

with torch.no_grad():
    output = model(**batch)
    
output
Out[15]:
(tensor(0.4709, device='cuda:0'),
 tensor([[-0.1386, -0.0835, -0.2738,  ..., -0.3862,  0.0325, -0.0017],
         [ 0.3649,  0.3073,  0.0864,  ..., -0.2507,  0.2876, -0.0527]],
        device='cuda:0'),
 tensor([[-0.0968, -0.2067, -0.0868,  ..., -0.4490, -0.0050,  0.0568],
         [-0.0548,  0.0111,  0.1266,  ..., -0.1976,  0.2901, -0.1404]],
        device='cuda:0'))

For our evaluation metric, we first compute the cosine similarity of each pair's embeddings, then calculate the Pearson and Spearman correlation with our ground truth similarity score.

In [16]:
def compute_metrics(eval_prediction):
    anchor_embedding, positive_embedding = eval_prediction.predictions
    labels = eval_prediction.label_ids

    cosine_sim = 1 - paired_cosine_distances(anchor_embedding, positive_embedding)
    pearson, _ = pearsonr(labels, cosine_sim)
    spearman, _ = spearmanr(labels, cosine_sim)
    return {"pearson": round(pearson, 3), "spearman": round(spearman, 3)}
In [17]:
os.environ['DISABLE_MLFLOW_INTEGRATION'] = 'TRUE'

finetuned_checkpoint = f"{model_name}-siamese"
training_args = TrainingArguments(
    output_dir=finetuned_checkpoint,
    num_train_epochs=2,
    learning_rate=0.0001,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=256,
    weight_decay=0.01,
    fp16=True,
    evaluation_strategy=IntervalStrategy.STEPS,
    save_strategy=IntervalStrategy.STEPS,
    eval_steps=500,
    save_steps=500,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="spearman"
)

trainer = Trainer(
    model,
    args=training_args, 
    data_collator=data_collator,
    train_dataset=dataset_dict_tokenized['train'],
    eval_dataset=dataset_dict_tokenized['validation'],
    compute_metrics=compute_metrics,
    callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
)
In [18]:
result = trainer.train()
/usr/local/lib/python3.8/dist-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
[6500/9824 08:57 < 04:35, 12.08 it/s, Epoch 1/2]
Step Training Loss Validation Loss Pearson Spearman
500 0.300400 2.122955 0.782000 0.755000
1000 0.234200 2.186666 0.789000 0.760000
1500 0.216200 2.237437 0.791000 0.762000
2000 0.202000 2.220945 0.803000 0.776000
2500 0.184900 2.192533 0.800000 0.772000
3000 0.182200 2.195967 0.802000 0.774000
3500 0.172800 2.239089 0.807000 0.779000
4000 0.161300 2.302346 0.801000 0.773000
4500 0.161600 2.299685 0.799000 0.773000
5000 0.145900 2.234666 0.806000 0.781000
5500 0.105800 2.264370 0.807000 0.781000
6000 0.112700 2.267491 0.806000 0.777000
6500 0.104400 2.288190 0.808000 0.781000

In [19]:
trainer.evaluate(dataset_dict_tokenized["test"])
[6/6 00:00]
Out[19]:
{'eval_loss': 2.7663321495056152,
 'eval_pearson': 0.839,
 'eval_spearman': 0.838,
 'eval_runtime': 0.3629,
 'eval_samples_per_second': 4133.329,
 'eval_steps_per_second': 16.533,
 'epoch': 1.32}

Evaluation

A quantitative benchmark on our test set shows we are capable of achieving an 80+% Spearman correlation. This section performs a manual inspection of some sample input sentences, where some are more similar to each other than others.

In [20]:
# tokenize our input sentences
sentences = [
    "it caught him off guard that space smelled of seared steak",
    "she could not decide between painting her teeth or brushing her nails",
    "he thought there'd be sufficient time is he hid his watch",
    "the bees decided to have a mutiny against their queen",
    "the sign said there was road work ahead so she decided to speed up",
    "on a scale of one to ten, what's your favorite flavor of color?",
    "flying stinging insects rebelled in opposition to the matriarch"
]
batch = tokenizer(sentences, padding=True, max_length=128, truncation=True, return_tensors="pt")
batch
Out[20]:
{'input_ids': tensor([[  101,  2009,  3236,  2032,  2125,  3457,  2008,  2686,  9557,  1997,
          2712,  5596, 21475,   102,     0,     0,     0,     0,     0],
        [  101,  2016,  2071,  2025,  5630,  2090,  4169,  2014,  4091,  2030,
         12766,  2014, 10063,   102,     0,     0,     0,     0,     0],
        [  101,  2002,  2245,  2045,  1005,  1040,  2022,  7182,  2051,  2003,
          2002, 11041,  2010,  3422,   102,     0,     0,     0,     0],
        [  101,  1996, 13734,  2787,  2000,  2031,  1037, 19306,  2114,  2037,
          3035,   102,     0,     0,     0,     0,     0,     0,     0],
        [  101,  1996,  3696,  2056,  2045,  2001,  2346,  2147,  3805,  2061,
          2016,  2787,  2000,  3177,  2039,   102,     0,     0,     0],
        [  101,  2006,  1037,  4094,  1997,  2028,  2000,  2702,  1010,  2054,
          1005,  1055,  2115,  5440, 14894,  1997,  3609,  1029,   102],
        [  101,  3909, 22748,  9728, 25183,  1999,  4559,  2000,  1996, 13523,
          4360, 11140,   102,     0,     0,     0,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]])}
In [21]:
# feed it through our siamese network's encoder
with torch.no_grad():
    embeddings = model.encoder(batch["input_ids"].to(device), batch["attention_mask"].to(device))

embeddings = embeddings.cpu().numpy()
embeddings
Out[21]:
array([[ 0.6091151 , -0.05289527, -0.735597  , ..., -0.20718901,
         0.9009036 , -0.08248363],
       [ 0.39508554, -0.67105526, -1.0313998 , ..., -0.5491935 ,
        -0.6127097 ,  0.34172556],
       [ 0.41667876,  0.01123782,  0.18218265, ..., -0.04224749,
        -0.41398963, -0.81505966],
       ...,
       [ 0.52507514,  0.10277902,  0.3072502 , ...,  0.30447486,
        -0.4546693 ,  0.26497468],
       [-0.78691447, -0.33936325,  0.02374081, ..., -0.25195748,
         0.06752441,  0.20980802],
       [-0.8355608 , -0.08190179, -1.2289685 , ..., -0.11544642,
         0.06740702,  0.32486174]], dtype=float32)

Given our sentence embeddings, we then use cosine similarity to find semantically similar sentences. Notice our trained model was able to match sentences that share very few words but are semantically similar.

The comparison is not shown here, but we can compare this with results from out-of-the-box pre-trained models and see that further training these networks with contrastive loss makes a night-and-day difference. i.e. if we were to directly derive sentence embeddings from a vanilla pre-trained model by either mean pooling its last hidden layer or extracting the first token's ([CLS] token's) embedding, it often yields rather poor sentence embeddings, as mentioned in Sentence-BERT [5]. A rough sketch of such a comparison follows the printout below.

In [22]:
input_embedding = embeddings[-1].reshape(1, -1).repeat(embeddings[:-1].shape[0], axis=0)
candidate_embedding = embeddings[:-1]
scores = 1 - paired_cosine_distances(input_embedding, candidate_embedding)
scores
Out[22]:
array([ 0.18791032,  0.21843803, -0.05063367,  0.6107122 ,  0.07014787,
        0.02572602], dtype=float32)
In [23]:
print(sentences[-1])
for i, score in enumerate(scores):
    score = round(float(score), 4)
    sentence = sentences[i]
    print(f"{score} | {sentence}")
flying stinging insects rebelled in opposition to the matriarch
0.1879 | it caught him off guard that space smelled of seared steak
0.2184 | she could not decide between painting her teeth or brushing her nails
-0.0506 | he thought there'd be sufficient time is he hid his watch
0.6107 | the bees decided to have a mutiny against their queen
0.0701 | the sign said there was road work ahead so she decided to speed up
0.0257 | on a scale of one to ten, what's your favorite flavor of color?
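
As a rough sketch of how such a comparison could be done (not executed here), we could mean pool a vanilla pre-trained checkpoint's last hidden state and compute the same cosine similarity scores; the un-finetuned embeddings would be expected to score noticeably worse:

# load the same pre-trained checkpoint without any contrastive fine-tuning
baseline_model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir).to(device)

with torch.no_grad():
    output = baseline_model(
        input_ids=batch["input_ids"].to(device),
        attention_mask=batch["attention_mask"].to(device)
    )

# mean pool the last hidden state, excluding padded positions
mask = batch["attention_mask"].to(device).unsqueeze(-1).expand_as(output.last_hidden_state).float()
baseline_embeddings = (output.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
baseline_embeddings = baseline_embeddings.cpu().numpy()

# same pairwise comparison as above, now with the un-finetuned embeddings
baseline_scores = 1 - paired_cosine_distances(
    baseline_embeddings[-1].reshape(1, -1).repeat(baseline_embeddings[:-1].shape[0], axis=0),
    baseline_embeddings[:-1]
)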

In this document, we walked through a modern baseline approach for training sentence transformers: modern in the sense that we are building on top of pre-trained language models instead of training from scratch, and baseline in the sense that there is a lot more we can do to improve these models' performance. We'll take a look at those tricks in future documents.

Reference

  • [1] Blog: Next-Gen Sentence Embeddings with Multiple Negatives Ranking Loss
  • [2] Blog: Real World Recommendation System – Part 1
  • [3] Sentence Transformers Documentation: Retrieve & Re-Rank
  • [4] Github: Sentence Transformers Training Natural Language Inference
  • [5] Paper: Nils Reimers, Iryna Gurevych - Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks - 2019