%load_ext watermark
%watermark -a 'Ethen' -d -u
From a practical standpoint, sentence embeddings are particularly useful in:
In this transformer-based deep learning era, it is often beneficial to warm start our models from pre-trained checkpoints. One important caveat is that popular pre-trained models trained with a masked language modeling objective, such as BERT [16] or RoBERTa [17], don't generate meaningful sentence embeddings out of the box, as noted by one of the original authors. SBERT (Sentence-BERT) [5] further demonstrated that it is beneficial to perform additional fine-tuning on top of pre-trained transformer based language models for competitive sentence embedding performance.
For the fine-tuning procedure, many works, including the original SBERT, relied on classification, regression, or pairwise triplet/hinge loss style objective functions [5] [14] [15]. But one of the SBERT authors later pointed out that contrastive learning is a much better approach [1], and E5 (EmbEddings from bidirEctional Encoder rEpresentations) [13], which also involves a pre-training stage with contrastive learning, tops the retrieval benchmark suite from MTEB (Massive Text Embedding Benchmark) [10]. So here we are, discussing some recipes for training sentence embeddings via contrastive learning.
As with most other use cases, we can of course use a more powerful encoder to generate the embedding representations for our input examples, but some tips specific to improving the performance of contrastive loss based learning involve:
In recent years, most contrastive learning procedures leverage variants of the InfoNCE (Noise Contrastive Estimation) loss [11]. This type of loss, sometimes referred to as NT-Xent (normalized temperature scaled cross entropy loss) as in SimCLR (simple contrastive learning of visual representations) [12], or as multiple negatives ranking loss [2], uses cross entropy loss to distinguish the positive from potentially multiple negative examples.
$$
\begin{aligned}
\mathcal{L} = -\log \frac{\exp \left(\operatorname{sim}\left(\boldsymbol{a}_i, \boldsymbol{p}_i\right) / \tau\right)}{\exp \left(\operatorname{sim}\left(\boldsymbol{a}_i, \boldsymbol{p}_i\right) / \tau\right) + \sum_{j=1}^m \exp \left(\operatorname{sim}\left(\boldsymbol{a}_i, \boldsymbol{n}_j\right) / \tau\right)}
\end{aligned}
$$

To train with this loss function, we need pairs of anchors and their corresponding positive examples. Here:
Note that if we choose cosine similarity as the similarity function and tune $\tau$ by hand, we should lean towards a smaller value. Cosine similarity is bounded in $[-1, 1]$, so without this sharpening the score differences are too small and don't lead to good empirical results.
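As a quick illustration, the snippet below uses made-up cosine similarities between one anchor and three candidates (the first being the positive) to show how dividing by a small temperature sharpens the distribution the cross entropy loss sees:

import torch

# hypothetical cosine similarities: positive first, then two negatives
cosine_scores = torch.tensor([0.9, 0.7, 0.6])
print(torch.softmax(cosine_scores, dim=0))         # ~[0.39, 0.32, 0.29], weak training signal
print(torch.softmax(cosine_scores / 0.05, dim=0))  # ~[0.98, 0.02, 0.00], sharp training signal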
CLIP (Contrastive Language-Image Pre-Training) [8] introduces a symmetric version of this loss function. PyTorch style pseudocode is shown below:
import torch
import torch.nn.functional as F

# @ denotes matrix multiplication; scoring each anchor against every positive
# in the batch is the in-batch negative operation mentioned in the next section.
scores = anchor_embedding @ positive_embedding.T / temperature
labels = torch.arange(len(scores), device=scores.device, dtype=torch.long)
loss = F.cross_entropy(scores, labels)

# clip style symmetric version
anchor_scores = anchor_embedding @ positive_embedding.T / temperature
positive_scores = positive_embedding @ anchor_embedding.T / temperature
labels = torch.arange(len(anchor_scores), device=anchor_scores.device, dtype=torch.long)
loss = (F.cross_entropy(anchor_scores, labels) + F.cross_entropy(positive_scores, labels)) / 2
`anchor_embedding` and `positive_embedding` are the embedding representations of the anchor and positive samples, each with shape [batch size, hidden size]. They are generated by a backbone encoder (e.g. a transformer), with or without a linear projection of our choice. The encoders can even share weights, which is commonly referred to as a siamese network.
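As a concrete example, here is a minimal sketch of producing these embeddings with a shared (siamese) transformer encoder and mean pooling; the checkpoint name and pooling choice are illustrative, not prescriptions from the referenced works:

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

# illustrative backbone; any transformer encoder checkpoint can be swapped in
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder = AutoModel.from_pretrained("bert-base-uncased")

def encode(sentences):
    """Mean pool token embeddings into sentence embeddings of shape [batch size, hidden size]."""
    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    token_embeddings = encoder(**batch).last_hidden_state
    mask = batch["attention_mask"].unsqueeze(-1).float()
    sentence_embeddings = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1)
    # l2 normalize so the dot product in the loss above becomes cosine similarity
    return F.normalize(sentence_embeddings, dim=-1)

# the same encoder (shared weights) produces both sides, i.e. a siamese setup
anchor_embedding = encode(["how do I train sentence embeddings"])
positive_embedding = encode(["training sentence embedding models with contrastive loss"])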
In-batch negatives are widely used for training models with contrastive loss. Assuming there are $B$ positive pairs in a given mini batch, each of these positive pairs can be paired with $B - 1$ negatives (the rest of the positive pairs within the same mini batch). This paradigm allows us to leverage the already loaded mini batch rather than spending additional resources to sample negative examples. When relying on in-batch negative sampling, using a larger batch size is key, as it allows the loss function to optimize over a more diverse set of negative samples, i.e. it's easier to find the right answer over a pool of 2 candidates versus, say, 1024 candidates. This can be treated as an implicit version of hard negative mining.
Works such as RocketQA [6] and CLIP [8] further mention the use of cross gpu/batch negatives: when training on multiple GPUs, the embedding calculation can be sharded within each single GPU, and these embeddings can then be shared among all GPUs to serve as additional negative examples. i.e. with $A$ GPUs, we can now collect $A \times B - 1$ negatives. Note, sharing here refers to a differentiable all gather operation.
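A rough sketch of such a differentiable all gather in PyTorch is shown below, assuming torch.distributed has already been initialized; this is an illustrative implementation, not the exact one used by RocketQA or CLIP:

import torch
import torch.distributed as dist

class GatherWithGrad(torch.autograd.Function):
    """All gather embeddings across GPUs while keeping gradients flowing back to each rank."""

    @staticmethod
    def forward(ctx, embedding):
        gathered = [torch.zeros_like(embedding) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, embedding)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # every rank's loss depends on this rank's embeddings, so sum the gradients
        # across ranks and return only the slice that belongs to this rank
        grad_output = grad_output.contiguous()
        dist.all_reduce(grad_output, op=dist.ReduceOp.SUM)
        local_batch_size = grad_output.shape[0] // dist.get_world_size()
        rank = dist.get_rank()
        return grad_output[rank * local_batch_size:(rank + 1) * local_batch_size]

# usage: score local anchors against positives gathered from all GPUs, shifting
# the labels by this rank's offset within the gathered batch
# all_positives = GatherWithGrad.apply(positive_embedding)
# scores = anchor_embedding @ all_positives.T / temperature
# labels = torch.arange(len(anchor_embedding), device=scores.device) + dist.get_rank() * len(anchor_embedding)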
With this approach, CLIP reports using an effective batch size as large as 32,768, sharded across 256 GPUs. It should be noted that the optimal size depends on training data size; CLIP's training data consists of more than 400M examples. As with all hyper parameters, the optimal setting will be use-case dependent. When pressed for time, this is most likely the parameter we should prioritize tuning above everything else.
Given the importance of large batch size in contrastive learning, it can be helpful to employ techniques that:
By increasing the training batch size, we can generate a larger number of negative samples from in-batch negative sampling. However, many of these negatives may be easily distinguishable by our model, thereby providing limited additional information after a certain point. To address this, we need a mechanism to identify hard negatives, which are more challenging for the model to discern. In other words, instead of providing pairs of anchors $a_i$ and positives $p_i$ as our input data, we now provide triplets $a_i, p_i, n_i$, where the negative $n_i$ should be similar to $p_i$ but shouldn't match with $a_i$. While forming a contrastive learning training batch, these explicit hard negatives are concatenated with the other in batch negatives, as shown in the diagram below.
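A minimal sketch of this concatenation is shown below, assuming one mined hard negative per anchor; the tensor names and the single-negative-per-anchor setup are illustrative:

import torch
import torch.nn.functional as F

def contrastive_loss_with_hard_negatives(anchor_embedding, positive_embedding, hard_negative_embedding, temperature=0.05):
    """anchor/positive/hard negative embeddings all have shape [batch size, hidden size]."""
    # candidate pool = all positives (in batch negatives) plus all mined hard negatives
    candidates = torch.cat([positive_embedding, hard_negative_embedding], dim=0)  # [2 * batch size, hidden size]
    scores = anchor_embedding @ candidates.T / temperature                        # [batch size, 2 * batch size]
    # the correct candidate for anchor i is still positive i, located at column i
    labels = torch.arange(len(anchor_embedding), device=scores.device, dtype=torch.long)
    return F.cross_entropy(scores, labels)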
It is worth noting:
Primary strategies for providing these hard negatives include:
With this strategy, we need to be mindful and ensure these examples are actually negatives. It is typically infeasible to scan through our entire dataset and label all the positive examples for a given anchor. Hence, when sampling hard negatives, we might accidentally sample a positive example that wasn't labeled, introducing false negatives into the mix. To solve this, we can train a separate cross-encoder model, which is typically more powerful at capturing semantic similarity than a bi-encoder, to denoise our hard negatives. In other words, when sampling hard negatives from the top ranked examples using the aforementioned strategies, we only select the ones that are predicted as negatives by the cross encoder with a high confidence score.
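For illustration, here is a hedged sketch of this denoising step using the sentence-transformers CrossEncoder API; the checkpoint name and score threshold are assumptions to be tuned for our own data:

from sentence_transformers import CrossEncoder

# illustrative cross encoder checkpoint trained for relevance scoring
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

def denoise_hard_negatives(anchor, candidates, threshold=0.1):
    """Keep only the candidates that the cross encoder confidently scores as non-matches."""
    scores = cross_encoder.predict([(anchor, candidate) for candidate in candidates])
    return [candidate for candidate, score in zip(candidates, scores) if score < threshold]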
Training with a mix of both random and hard negatives is oftentimes beneficial; the optimal proportion of the two is something we'll have to experiment with for our use case. The overall data augmentation workflow can be roughly summarized into the following steps [6] [7]:
BEIR (Benchmarking Information Retrieval) [20] showed that these dense embedding retrieval models require large amounts of training data to work well. If we are lacking data for a particular domain of interest, one trend in information retrieval is to leverage generative models to generate synthetic data for training ranking models. At its core, the idea is that large generative models have demonstrated impressive results across the board on many NLP tasks; however, they can be expensive to apply at run time. Hence, instead of using them directly in our system, we would like to apply them in an offline setting, where we use them to generate more in-domain training data for training our actual retrieval or ranking model. This can arguably be seen as an alternative form of distillation, where we aim to distill the knowledge of large generative models into small bi-encoder or cross encoder models via prompted generation. This high level workflow is depicted in the figure below:
To elaborate:
We can imagine many variations stemming from this workflow, including different prompting templates, which language or seq2seq model to use, the choice of re-ranker models, etc. Works such as InPars (Inquisitive Parrots for Search) v1/v2/light [22] [23] [24], Prompt-base Query Generation for Retriever (PROMPTAGATOR) [25], and GPL (Generative Pseudo Labeling) [26] all demonstrate different design choices and their effectiveness to various degrees.
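As one hedged example of the query generation step in this workflow, the snippet below uses a publicly available doc2query-style seq2seq checkpoint to create synthetic queries for a passage; the model name and sampling parameters are illustrative and not taken from any single paper above:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# illustrative query generation checkpoint trained on MS MARCO
model_name = "BeIR/query-gen-msmarco-t5-base-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

def generate_synthetic_queries(passage, num_queries=3):
    """Generate (query, passage) pairs that can later be filtered or re-ranked for training."""
    inputs = tokenizer(passage, truncation=True, max_length=512, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_length=64,
        do_sample=True,
        top_p=0.95,
        num_return_sequences=num_queries,
    )
    return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]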
Note, in the pure image domain, works such as MoCo (Momentum Contrast) [18] introduce the concepts of a memory queue and a momentum encoder for improving contrastive learning results. Neither is elaborated upon here as:
We can refer to the public pre-trained sentence transformer's model card [3] and E5 [13] for public datasets that can be combined and used to fine-tune these models.
As with all methods, it is important to find an established benchmark dataset so we can quickly iterate on new ideas. MTEB [10] has collected 8 embedding task types, ranging from semantic textual similarity (STS, SemEval), classification (fine tuning a classifier using the embedding as input features, SentEval), information retrieval (BEIR [20]), etc. In total, it consists of 56 datasets covering 112 languages. They also evaluated 30 different models to provide a holistic view of state of the art public pre-trained text embedding models.
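A short sketch of running such an evaluation with the mteb package is shown below; the model checkpoint, task names, and output folder are just examples:

from mteb import MTEB
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
evaluation = MTEB(tasks=["Banking77Classification", "STSBenchmark"])
results = evaluation.run(model, output_folder="results/all-MiniLM-L6-v2")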
Two tower based models are extremely popular in industry for various embedding based retrieval use cases, and for good reasons:
A non-exhaustive list of their presence in the industry includes: Amazon's semantic product search [14] as well as query rewriting [35], Facebook search's multi-modal embedding based retrieval [15] [29], Google Play's retrieval system [30], YouTube's fresh content retrieval [36], and Bing sponsored search's multi-objective retrieval system [31].
Even in the context of modern large language models, retrieval augmented generation [33] is still a highly efficient method to avoid stuffing everything inside a model's prompt, and to avoid having to re-train the model to prevent it from generating potentially outdated information.