In [1]:
# code for loading the format for the notebook
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', '..', 'notebook_format'))

from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
Out[1]:
In [2]:
os.chdir(path)

%load_ext watermark
%load_ext autoreload
%autoreload 2

import math
import peft
import faiss
import torch
import datasets
import transformers
import numpy as np
import pandas as pd
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F
import torch.distributed as dist
from torch.optim import AdamW
from torch.distributed import nn as dist_nn
from torch.utils.data import DataLoader
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers.utils import PaddingStrategy
from transformers.data.data_collator import DataCollatorMixin
from transformers import (
    AutoModel,
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase
)
from dataclasses import dataclass
from typing import List, Optional, Dict, Any, Union

%watermark -a 'Ethen' -d -u -v -iv
Author: Ethen

Last updated: 2024-03-19

Python implementation: CPython
Python version       : 3.9.18
IPython version      : 8.18.1

datasets         : 2.14.7
pandas           : 2.2.0
numpy            : 1.23.5
torch            : 2.1.2
peft             : 0.9.0
faiss            : 1.7.2
transformers     : 4.37.0
pytorch_lightning: 2.1.4

Multilingual Sentence Embedding with LLM and PEFT LoRA

In this article, we'll take a look at training a multilingual sentence embedding model with a Large Language Model (LLM) and a parameter efficient fine tuning technique, LoRA (Low-Rank Adaptation).

LLM For Retrieval

Large Language Models (LLMs) with billions of parameters, fine-tuned to follow instructions, have showcased remarkable capabilities on many NLP tasks. Consequently, there's growing interest in harnessing these models as backbones for retrieval systems, e.g. LLaMA-2 [7], GPT [10] and Mistral [11].

RepLLaMA/RankLLaMA [7] leverages LLaMA-2-7B as the backbone for training a retrieval model and a re-ranker model, respectively. Previous work on dense retrievers often uses a bi-directional encoder model such as BERT, taking the representation of a prepended special [CLS] token, or average pooling over all tokens, as the sentence embedding. Since LLaMA is a uni-directional, decoder-only model, an end-of-sentence token </s> is instead appended to the input, and its final hidden state serves as the embedding: under causal attention, this last token is the only position that attends to the entire sequence.
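To see why the last token is a natural choice for a decoder-only model, the toy sketch below builds a causal attention mask (position i may attend to positions 0..i) and contrasts encoder-style mean pooling with decoder-style last-token pooling on random hidden states. The tensor sizes are arbitrary assumptions and no real model is involved.

```python
import torch

seq_len, hidden_dim = 5, 4

# causal mask: row i marks the positions token i is allowed to attend to
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# only the final token (e.g. an appended eos) attends to every position
last_row_sees_all = bool(causal_mask[-1].all())
first_row_sees_all = bool(causal_mask[0].all())
print(last_row_sees_all, first_row_sees_all)  # True False

# toy "last hidden states" for a single, unpadded sequence
hidden_states = torch.randn(1, seq_len, hidden_dim)

# encoder style: average over all token representations
mean_pooled = hidden_states.mean(dim=1)
# decoder style: take the representation at the final position
last_token_pooled = hidden_states[:, -1]

print(mean_pooled.shape, last_token_pooled.shape)
```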

To address the high GPU memory cost of fine tuning large models with contrastive learning, they leverage memory efficient techniques such as LoRA, flash attention, and gradient checkpointing. The model is trained on 16 × 32GB V100 GPUs with a batch size of 128, using hard negatives drawn from a blend of BM25 and CoCondenser results, so that hard negatives come from both sparse and dense retrieval.
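A contrastive objective with hard negatives can be sketched as follows: each query is scored against every positive in the batch (in-batch negatives) plus a set of mined hard negatives, and cross entropy pushes the score of its own positive above the rest. This is a minimal sketch of a common InfoNCE formulation, not RepLLaMA's code; it assumes one hard negative per query and a fixed temperature of 0.05, and the function name `info_nce_with_hard_negatives` is hypothetical. The CLIPLoss used later in this notebook relies on in-batch negatives only.

```python
import torch
import torch.nn.functional as F


def info_nce_with_hard_negatives(query_emb, positive_emb, hard_negative_emb, temperature=0.05):
    """
    query_emb, positive_emb, hard_negative_emb: (batch_size, dim) tensors,
    where hard_negative_emb[i] is a mined hard negative for query i.
    """
    query_emb = F.normalize(query_emb, p=2, dim=1)
    # candidates: all in-batch positives followed by all hard negatives
    candidates = torch.cat([positive_emb, hard_negative_emb], dim=0)
    candidates = F.normalize(candidates, p=2, dim=1)

    # (batch_size, 2 * batch_size) cosine similarities, sharpened by the temperature
    scores = query_emb @ candidates.T / temperature

    # query i's correct candidate is positive i, i.e. column i
    labels = torch.arange(query_emb.size(0), device=query_emb.device)
    return F.cross_entropy(scores, labels)


torch.manual_seed(0)
batch_size, dim = 4, 8
queries = torch.randn(batch_size, dim)
positives = queries + 0.1 * torch.randn(batch_size, dim)  # close to their own queries
hard_negatives = torch.randn(batch_size, dim)
loss = info_nce_with_hard_negatives(queries, positives, hard_negatives)
print(loss.item())
```

Mismatching queries and positives (e.g. rolling `positives` by one row) should give a much larger loss, which is a quick way to sanity check such an implementation.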

Apart from strong performance on the in-domain MS MARCO dataset and in zero-shot evaluation on the BEIR benchmark suite, this approach also benefits from the fact that modern LLMs are often pre-trained with a longer context window.

LoRA

In the era of pre-trained transformer models, many applications rely on fine tuning one large pre-trained model for multiple downstream tasks. Given the high cost of full fine tuning, many have sought to adapt only a subset of parameters, e.g. by freezing the base layers. LoRA (Low-Rank Adaptation) [9] presents an alternative approach: it keeps the pre-trained weights frozen and represents the weight update with two low rank matrices.

Following the LoRA paper: given a pre-trained weight matrix $W_0 \in \mathbb{R}^{d \times d}$, we constrain its update by representing it with a low rank decomposition $W_0 + \Delta W = W_0 + BA$, where $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times d}$, and the rank $r \ll d$. During training, $W_0$ is frozen, while $A$ and $B$ contain the trainable parameters. Both $W_0$ and $\Delta W$ receive the same input during the forward pass, and their outputs are summed: $h = W_0 x + \frac{\alpha}{r} \Delta W x = W_0 x + \frac{\alpha}{r} BA x$, where $\alpha$ is a constant (`lora_alpha` in peft) and $\frac{\alpha}{r}$ acts as the scaling factor. At the beginning of training, $A$ is initialized with random Gaussian values and $B$ with zeros, so $\Delta W = BA$ starts at zero.
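A quick numerical check of these definitions: with $d = 2048$ (the hidden size of the bloomz-1b7 model used later) and an assumed rank $r = 8$, the sketch below counts the parameters of $W_0$ versus $A$ and $B$, confirms that the zero-initialized $B$ makes the adapted layer reproduce $W_0 x$ at the start of training, and checks that $\Delta W = BA$ has rank at most $r$. The value alpha = 16 is an assumption mirroring the LoraConfig used later, and the update is scaled by alpha / r.

```python
import torch

torch.manual_seed(0)
d, r, alpha = 2048, 8, 16
scaling = alpha / r

W0 = torch.randn(d, d)          # frozen pre-trained weight
A = torch.randn(r, d) * 0.01    # small random Gaussian init
B = torch.zeros(d, r)           # zero init
x = torch.randn(d)

# parameter counts: full update vs low rank update
full_params = W0.numel()
lora_params = A.numel() + B.numel()
print(full_params, lora_params, lora_params / full_params)
# 4194304 32768 0.0078125

# at initialization BA = 0, so the adapted layer equals the frozen one
h = W0 @ x + scaling * (B @ (A @ x))
print(torch.allclose(h, W0 @ x))  # True

# after (simulated) training, delta W = BA can have rank at most r
B_trained = torch.randn(d, r)
delta_W = B_trained @ A
rank_delta = torch.linalg.matrix_rank(delta_W).item()
print(rank_delta)  # 8
```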

Its advantages:

  • A pre-trained model can be shared and used to build many small LoRA modules for different tasks.
  • Compared to full fine tuning, training becomes more efficient, as LoRA drastically reduces the number of trainable parameters (and hence the gradient and optimizer state memory). This lowers the hardware barrier and speeds up the training cycle, especially for billion-parameter pre-trained models.
  • Its linear design allows us to merge LoRA's trainable matrices with the original frozen weights after training, introducing no additional inference latency compared to the original model.
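# simplified lora linear layer (row-vector convention: x @ A @ B corresponds to BAx in the paper)
class LoraLinear(nn.Linear):

    def __init__(
        self,
        in_features,
        out_features,
        rank,
        alpha,
        bias=True,
        device=None,
        dtype=None
    ):
        super().__init__(in_features, out_features, bias, device, dtype)
        self.rank = rank
        self.alpha = alpha
        # the LoRA paper scales the low rank update by alpha / rank
        self.scaling = alpha / rank

        # A: small random Gaussian, B: zeros, so the update starts at zero
        self.lora_A = nn.Parameter(torch.randn(in_features, rank, device=device, dtype=dtype) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features, device=device, dtype=dtype))

        # freeze the original linear layer's weight matrix and bias
        self.weight.requires_grad = False
        if self.bias is not None:
            self.bias.requires_grad = False

    def forward(self, x):
        lora_output = x @ self.lora_A @ self.lora_B * self.scaling
        return super().forward(x) + lora_output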
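The zero-latency claim follows from linearity: $W_0 x + \frac{\alpha}{r} BAx = (W_0 + \frac{\alpha}{r} BA)x$, so the low rank update can be folded into the frozen weight once training is done. The sketch below checks this identity with a plain `nn.Linear` in the row-vector convention of the simplified layer above; the sizes, rank and alpha are arbitrary assumptions, and `B` is filled with random values to stand in for a trained matrix.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_features, out_features, rank, alpha = 16, 12, 4, 8
scaling = alpha / rank

base = nn.Linear(in_features, out_features)
lora_A = torch.randn(in_features, rank) * 0.1
lora_B = torch.randn(rank, out_features) * 0.1  # pretend B has been trained
x = torch.randn(3, in_features)

with torch.no_grad():
    # unmerged: frozen layer plus a separate low rank branch (extra matmuls at inference)
    unmerged = base(x) + x @ lora_A @ lora_B * scaling

    # merged: fold the update into the weight; nn.Linear stores its weight as (out, in)
    merged = nn.Linear(in_features, out_features)
    merged.weight.copy_(base.weight + scaling * (lora_A @ lora_B).T)
    merged.bias.copy_(base.bias)
    merged_output = merged(x)

print(torch.allclose(unmerged, merged_output, atol=1e-5))  # True
```

The merged layer has exactly the shape of the original one, which is why no extra latency remains; peft's `merge_and_unload`, used later in the Evaluation section, performs this kind of merge for us.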

Data

We'll use the bloomz model family for our tokenizer and model. While it can be substituted with any other Large Language Model (LLM), we opted for bloomz for its multilingual capabilities.

ESCI

For our dataset, we take inspiration from one of the examples in the peft library's documentation [2]. Specifically, we'll use a small subset of the ESCI e-commerce search query dataset, conveniently available on the Hugging Face Hub. The ESCI dataset [3] [8], available in multiple languages including English, Japanese, and Spanish, consists of challenging search queries (such as those involving negations, e.g. "energy bar without nuts", or "gluten-free biscuits"), each paired with up to 40 search results along with their ESCI (Exact, Substitute, Complement, Irrelevant) judgments. Our task at hand is to train a model that retrieves relevant products for a given query; we keep only query-product pairs with relevance_label == 1 as positives.

In [3]:
# https://huggingface.co/bigscience/bloomz-1b7
model_name_or_path = "bigscience/bloomz-1b7"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
In [4]:
dataset_name = "smangrul/amazon_esci"
dataset_dict = load_dataset(dataset_name).filter(lambda example: example["relevance_label"] == 1)
print(dataset_dict["train"][0])
dataset_dict
Filter: 100%|██████████| 839306/839306 [00:07<00:00, 106693.67 examples/s]
Filter: 100%|██████████| 363402/363402 [00:03<00:00, 106600.45 examples/s]
{'query': '!awnmower tires without rims', 'product_title': 'MaxAuto 2-Pack 13x5.00-6 2PLY Turf Mower Tractor Tire with Yellow Rim, (3" Centered Hub, 3/4" Bushings )', 'product_id': 'B08L3B9B9P', 'esci_label': 'E', 'split': 'train', 'relevance_label': 1, '__index_level_0__': 17}

Out[4]:
DatasetDict({
    train: Dataset({
        features: ['query', 'product_title', 'product_id', 'esci_label', 'split', 'relevance_label', '__index_level_0__'],
        num_rows: 658894
    })
    validation: Dataset({
        features: ['query', 'product_title', 'product_id', 'esci_label', 'split', 'relevance_label', '__index_level_0__'],
        num_rows: 286542
    })
})
In [5]:
@dataclass
class DataCollatorForSentenceEmbedding(DataCollatorMixin):
    """
    Tokenize raw text and pad it while forming a batch for the data loader.
    Appends the eos token, whose last hidden state serves as the downstream embedding.
    """

    tokenizer: Optional[PreTrainedTokenizerBase] = None
    max_seq_len_1: int = 512
    max_seq_len_2: int = 512
    id_field: str = "__index_level_0__"
    text_field_1: str = "query"
    text_field_2: str = "product_title"
    process_tower: Optional[str] = None
    padding: Union[bool, str, PaddingStrategy] = True
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]], return_tensors=None) -> Dict[str, Any]:
        # id could be a string column, and is also not part of the module's forward pass,
        # hence converting it to a torch tensor isn't needed
        ids = [feature[self.id_field] for feature in features]

        if self.process_tower == "tower_1":
            formatted_text = [feature[self.text_field_1] + self.tokenizer.eos_token for feature in features]
            tokenized_text_1 = self.tokenizer(
                text=formatted_text,
                padding=self.padding,
                max_length=self.max_seq_len_1,
                truncation=True,
                return_attention_mask=True,
                return_tensors=self.return_tensors
            )

            batch = {
                "ids": ids,
                "input_ids": tokenized_text_1["input_ids"],
                "attention_mask": tokenized_text_1["attention_mask"]
            }
        elif self.process_tower == "tower_2":
            formatted_text = [feature[self.text_field_2] + self.tokenizer.eos_token for feature in features]
            tokenized_text_2 = self.tokenizer(
                text=formatted_text,
                padding=self.padding,
                max_length=self.max_seq_len_2,
                truncation=True,
                return_attention_mask=True,
                return_tensors=self.return_tensors
            )

            batch = {
                "ids": ids,
                "input_ids": tokenized_text_2["input_ids"],
                "attention_mask": tokenized_text_2["attention_mask"]
            }
        else:
            formatted_text_1 = [feature[self.text_field_1] + self.tokenizer.eos_token for feature in features]
            tokenized_text_1 = self.tokenizer(
                text=formatted_text_1,
                padding=self.padding,
                max_length=self.max_seq_len_1,
                truncation=True,
                return_attention_mask=True,
                return_tensors=self.return_tensors
            )

            formatted_text_2 = [feature[self.text_field_2] + self.tokenizer.eos_token for feature in features]
            tokenized_text_2 = self.tokenizer(
                text=formatted_text_2,
                padding=self.padding,
                max_length=self.max_seq_len_2,
                truncation=True,
                return_attention_mask=True,
                return_tensors=self.return_tensors
            )

            batch = {
                "ids": ids,
                "input_ids_1": tokenized_text_1["input_ids"],
                "input_ids_2": tokenized_text_2["input_ids"],
                "attention_mask_1": tokenized_text_1["attention_mask"],
                "attention_mask_2": tokenized_text_2["attention_mask"]
            }

        return batch
In [6]:
data_collator = DataCollatorForSentenceEmbedding(tokenizer)
dataloader_train = DataLoader(
    dataset_dict["train"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=2,
    pin_memory=True,
)
batch = next(iter(dataloader_train))
batch
Out[6]:
{'ids': [699034, 746719],
 'input_ids_1': tensor([[ 24027,   2969,   8457,   2629, 170205,      2],
         [ 84846,   6303,   5669,   1640,  15486,      2]]),
 'input_ids_2': tensor([[     3,      3,      3,      3,      3,      3,      3,      3,   1980,
            6844, 150996,    337,   3846,    375,   3548,  13281,  78211,     12,
               2],
         [ 57277,  98007,  64937,   3541,  49761,  84109, 115011,  18832,   2967,
           86498,  67901,   2137,  18728,    530,   8557,  15486,     15,  21107,
               2]]),
 'attention_mask_1': tensor([[1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1]]),
 'attention_mask_2': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

Model

The next code chunk defines a Hugging Face compatible SentenceEmbeddingModel for training a retrieval model with contrastive learning. Rather than using our simplified LoraLinear, we'll directly leverage the peft library for the actual LoRA experiments.

In [7]:
class SentenceEmbeddingModelConfig(PretrainedConfig):
    model_type = "sentence_embedding"

    def __init__(
        self,
        model_name: Optional[str] = None,
        normalize: bool = True,
        cross_gpu_negatives: bool = False,
        enable_gradient_checkpointing: bool = True,
        peft_config = None,
        **kwargs
    ):
        # initialize PretrainedConfig's default attributes
        super().__init__(**kwargs)
        self.model_name = model_name
        self.normalize = normalize
        self.cross_gpu_negatives = cross_gpu_negatives and torch.cuda.device_count() > 1
        self.enable_gradient_checkpointing = enable_gradient_checkpointing
        self.peft_config = peft_config


class SentenceEmbeddingModel(PreTrainedModel):
    """
    Sentence embedding model trained with an InfoNCE style contrastive loss.
    Uses the last (eos) token's hidden state as the embedding, and supports
    gradient checkpointing and LoRA for memory efficient training.
    """

    config_class = SentenceEmbeddingModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        model = AutoModel.from_pretrained(config.model_name)
        if config.enable_gradient_checkpointing:
            model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )

        self.model = model
        if config.peft_config:
            self.model = get_peft_model(model, config.peft_config)
            self.model.print_trainable_parameters()

        self.loss = CLIPLoss(config.cross_gpu_negatives)

    def forward(
        self,
        input_ids_1,
        attention_mask_1,
        input_ids_2,
        attention_mask_2,
    ):
        embeddings_1 = self.encode(input_ids_1, attention_mask_1)
        embeddings_2 = self.encode(input_ids_2, attention_mask_2)
        loss = self.loss(embeddings_1, embeddings_2)
        return loss, embeddings_1, embeddings_2

    def encode(self, input_ids, attention_mask):
        model_output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = last_token_pooling(model_output.last_hidden_state, attention_mask)
        if self.config.normalize:
            embeddings = F.normalize(embeddings, p=2, dim=1)

        return embeddings


def last_token_pooling(last_hidden_states, attention_mask):
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]



class CLIPLoss(nn.Module):
    """
    Symmetric contrastive learning loss, a.k.a. CLIP loss, or a symmetric variant of
    the MultipleNegativesRankingLoss from the sentence-transformers package.

    References
    ----------
    - https://arxiv.org/abs/2103.00020
    - https://www.sbert.net/docs/package_reference/losses.html#multiplenegativesrankingloss
    """

    def __init__(self, cross_gpu_negatives: bool = True):
        super().__init__()
        self.cross_gpu_negatives = cross_gpu_negatives

        # trainable temperature parameters
        # This initial value is based on open clip
        # https://github.com/mlfoundations/open_clip/blob/4b929357093bfbb0986b61cfa23776f1dc740370/src/open_clip/model.py
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, anchor_embedding, positive_embedding):
        with torch.no_grad():
            self.logit_scale.clamp_(0, math.log(100))

        logit_scale = self.logit_scale.exp()
        if self.cross_gpu_negatives:
            anchor_embedding_all_gathered = torch.cat(dist_nn.all_gather(anchor_embedding), dim=0)
            positive_embedding_all_gathered = torch.cat(dist_nn.all_gather(positive_embedding), dim=0)
            anchor_scores = anchor_embedding @ positive_embedding_all_gathered.T * logit_scale
            positive_scores = positive_embedding @ anchor_embedding_all_gathered.T * logit_scale
            rank = dist.get_rank()
        else:
            anchor_scores = anchor_embedding @ positive_embedding.T * logit_scale
            positive_scores = positive_embedding @ anchor_embedding.T * logit_scale
            rank = 0

        # Example a[i] should match with p[i]
        batch_size = anchor_scores.size()[0]
        labels = torch.arange(batch_size, device=anchor_scores.device, dtype=torch.long)
        labels = labels + batch_size * rank
        loss = (F.cross_entropy(anchor_scores, labels) + F.cross_entropy(positive_scores, labels)) / 2
        return loss
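The `last_token_pooling` helper above has to handle both padding sides. The bloomz tokenizer pads on the left (see the leading 3s with attention mask 0 in the earlier batch), in which case the eos token always sits at the last position; with right padding, the position of the last attended token is recovered from the attention mask instead. The sketch below re-declares the same helper so it runs on its own and checks, on a toy tensor with arbitrary sizes, that both layouts select the same hidden states.

```python
import torch


def last_token_pooling(last_hidden_states, attention_mask):
    # same logic as the helper defined in the cell above
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]


torch.manual_seed(0)
hidden_dim = 4
# two sequences with 2 and 4 real tokens; the last real token plays the role of eos
seq_a = torch.randn(2, hidden_dim)
seq_b = torch.randn(4, hidden_dim)
pad = torch.zeros(2, hidden_dim)

# left padding: padding first, real tokens last
left_states = torch.stack([torch.cat([pad, seq_a]), seq_b])
left_mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])

# right padding: real tokens first, padding last
right_states = torch.stack([torch.cat([seq_a, pad]), seq_b])
right_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])

left_pooled = last_token_pooling(left_states, left_mask)
right_pooled = last_token_pooling(right_states, right_mask)
print(torch.equal(left_pooled, right_pooled))  # True
```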

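As part of our LoraConfig, we need to specify target_modules, which decides which modules get a LoRA adapter. For a list of strings, peft matches a module when its full name equals one of the entries or ends with "." followed by an entry; a single string is instead treated as a regular expression that must match the full module name. LoRA can be applied to many module types in our model, though a common practice for transformer style models is to apply it to the attention layer's query, key and value matrices as well as the feed forward layers that follow.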

Combined with gradient checkpointing, this LoRA setup lets us train a 1.7B parameter model on a single V100 GPU with a micro batch size of 64.
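To make the matching rule concrete, the sketch below re-implements the check, as we understand peft 0.9's behaviour, that decides whether a module receives a LoRA adapter. The helper `matches_target` is hypothetical (it is not a peft API), and the module names are a hand-written subset following BloomModel's naming in the model printout below.

```python
import re
from typing import List, Union


def matches_target(module_name: str, target_modules: Union[str, List[str]]) -> bool:
    """Hypothetical helper mirroring list/regex style target_modules matching."""
    if isinstance(target_modules, str):
        # a single string is treated as a regex over the full module name
        return re.fullmatch(target_modules, module_name) is not None
    # a list entry matches the full name exactly, or as its trailing dotted component
    return module_name in target_modules or any(
        module_name.endswith(f".{target}") for target in target_modules
    )


target_modules = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
module_names = [
    "word_embeddings",
    "h.0.input_layernorm",
    "h.0.self_attention.query_key_value",
    "h.0.self_attention.dense",
    "h.0.mlp.dense_h_to_4h",
    "h.0.mlp.dense_4h_to_h",
]
for name in module_names:
    print(f"{name:40s} {matches_target(name, target_modules)}")
```

If this reading is right, `"dense"` on its own only matches the attention output projection (`h.0.mlp.dense_h_to_4h` does not end with `.dense`), which is why the MLP projections are listed separately in the config.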

In [8]:
# https://huggingface.co/docs/peft/en/package_reference/lora
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    bias="none",
    task_type=TaskType.FEATURE_EXTRACTION,
    # check each model's corresponding module name,
    # e.g. for BERT, target_modules=["key", "query", "value"],
    target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],

    # parameters that were not injected with LoRA are automatically
    # frozen, if we wish to train them, specify them via
    # modules_to_save
)
sentence_embedding_model_config = SentenceEmbeddingModelConfig(
    model_name_or_path,
    peft_config=peft_config
)
sentence_embedding_model = SentenceEmbeddingModel(sentence_embedding_model_config)
sentence_embedding_model
trainable params: 6,291,456 || all params: 1,728,700,416 || trainable%: 0.363941371319714
Out[8]:
SentenceEmbeddingModel(
  (model): PeftModelForFeatureExtraction(
    (base_model): LoraModel(
      (model): BloomModel(
        (word_embeddings): Embedding(250880, 2048)
        (word_embeddings_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (h): ModuleList(
          (0-23): 24 x BloomBlock(
            (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
            (self_attention): BloomAttention(
              (query_key_value): lora.Linear(
                (base_layer): Linear(in_features=2048, out_features=6144, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2048, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=6144, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (dense): lora.Linear(
                (base_layer): Linear(in_features=2048, out_features=2048, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2048, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (attention_dropout): Dropout(p=0.0, inplace=False)
            )
            (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
            (mlp): BloomMLP(
              (dense_h_to_4h): lora.Linear(
                (base_layer): Linear(in_features=2048, out_features=8192, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2048, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=8192, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (gelu_impl): BloomGelu()
              (dense_4h_to_h): lora.Linear(
                (base_layer): Linear(in_features=8192, out_features=2048, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=8192, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
            )
          )
        )
        (ln_f): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
  (loss): CLIPLoss()
)
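The trainable parameter count printed above can be reproduced by hand: each adapted linear layer with input size $d_{in}$ and output size $d_{out}$ adds $r(d_{in} + d_{out})$ parameters. The layer shapes below are read off the printout above, with $r = 8$ from our LoraConfig and 24 BloomBlocks.

```python
r = 8
num_layers = 24
# (in_features, out_features) of each LoRA-adapted linear layer in a BloomBlock
adapted_layers = {
    "query_key_value": (2048, 6144),
    "dense": (2048, 2048),
    "dense_h_to_4h": (2048, 8192),
    "dense_4h_to_h": (8192, 2048),
}

# lora_A maps d_in -> r and lora_B maps r -> d_out
params_per_block = sum(r * (d_in + d_out) for d_in, d_out in adapted_layers.values())
trainable_params = params_per_block * num_layers
all_params = 1_728_700_416  # from print_trainable_parameters above

print(params_per_block)   # 262144
print(trainable_params)   # 6291456
print(100 * trainable_params / all_params)
```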
In [9]:
ids = batch.pop("ids")
output = sentence_embedding_model(**batch)
In [10]:
class SentenceEmbeddingLightningModule(pl.LightningModule):

    def __init__(self, sentence_embedding_model: SentenceEmbeddingModel):
        super().__init__()
        self.sentence_embedding_model = sentence_embedding_model

        # huggingface's AutoModel loads the model in eval mode. Recent versions of
        # pytorch lightning no longer automatically switch the model to train mode
        # during the trainer's fit stage, so we need to call .train() explicitly
        # https://github.com/Lightning-AI/pytorch-lightning/issues/19467#issuecomment-1942741283
        self.sentence_embedding_model.train()

    def forward(self, **batch):
        return self.sentence_embedding_model(**batch)

    def training_step(self, batch, batch_idx):
        ids = batch.pop("ids")
        outputs = self(**batch)
        loss = outputs[0]
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def predict_step(self, batch, batch_idx):
        ids = batch.pop("ids")
        embeddings = self.sentence_embedding_model.encode(**batch)
        prediction_output = {"ids": ids, "embeddings": embeddings}
        return prediction_output

    def configure_optimizers(self):
        model = self.sentence_embedding_model

        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": 0.001,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            }
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=0.0001)
        return optimizer
In [11]:
sentence_embedding_module = SentenceEmbeddingLightningModule(sentence_embedding_model)
trainer = pl.Trainer(
    accelerator="gpu",
    devices=-1,
    max_steps=2000,
    precision="16-mixed",
    # note: we purposefully disabled the progress bar to prevent flooding the notebook's console,
    # in normal settings, we can/should turn it on
    enable_progress_bar=False,
    log_every_n_steps=50,
)
dataloader_train = DataLoader(
    dataset_dict["train"],
    shuffle=True,
    collate_fn=data_collator,
    num_workers=2,
    batch_size=64,
    pin_memory=True,
)
trainer.fit(sentence_embedding_module, dataloader_train)
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                     | Type                   | Params
--------------------------------------------------------------------
0 | sentence_embedding_model | SentenceEmbeddingModel | 1.7 B 
--------------------------------------------------------------------
6.3 M     Trainable params
1.7 B     Non-trainable params
1.7 B     Total params
6,914.802 Total estimated model params size (MB)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
`Trainer.fit` stopped: `max_steps=2000` reached.
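The "Total estimated model params size (MB)" line in the log above is the parameter count times 4 bytes (fp32). Extending the same back-of-the-envelope arithmetic, under the common assumption of 16 bytes per trainable parameter (fp32 weight, gradient and two AdamW moments) and 4 bytes per frozen parameter, ignoring activations and autocast copies, shows how much optimizer-related memory LoRA saves.

```python
all_params = 1_728_700_416        # total parameters reported by peft above
trainable_params = 6_291_456      # LoRA parameters reported by peft above
bytes_fp32 = 4

# matches the "Total estimated model params size (MB)" line, up to rounding
params_size_mb = all_params * bytes_fp32 / 1e6
print(round(params_size_mb, 3))  # 6914.802

# rough estimate: full fine tuning keeps weight + gradient + two AdamW moments for every
# parameter, LoRA keeps the extra three buffers only for its trainable parameters
full_finetune_gb = all_params * bytes_fp32 * 4 / 1e9
lora_gb = (all_params * bytes_fp32 + trainable_params * bytes_fp32 * 3) / 1e9
print(f"full fine tuning: {full_finetune_gb:.1f} GB, LoRA: {lora_gb:.1f} GB")
```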

Evaluation

The evaluation process involves:

  • Generating embeddings for both the distinct queries and the products (corpus).
  • Retrieving the top-k products for each query using FAISS's flat index, i.e. exact search; since our embeddings are L2-normalized, inner product equals cosine similarity.
  • Computing evaluation metrics, in this case recall@k.
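Because `SentenceEmbeddingModel` L2-normalizes its outputs (`normalize=True` by default), maximizing the inner product, which is what `faiss.IndexFlatIP` does, is the same as maximizing cosine similarity. The numpy sketch below performs the same kind of exact top-k search by brute force on random toy embeddings, which can be handy for sanity checking FAISS results on a small sample; the sizes and `topk` are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
num_corpus, num_queries, dim, topk = 100, 5, 16, 3

corpus_raw = rng.standard_normal((num_corpus, dim)).astype(np.float32)
queries_raw = rng.standard_normal((num_queries, dim)).astype(np.float32)

# L2 normalize each row, mirroring F.normalize(embeddings, p=2, dim=1) in the model
corpus = corpus_raw / np.linalg.norm(corpus_raw, axis=1, keepdims=True)
queries = queries_raw / np.linalg.norm(queries_raw, axis=1, keepdims=True)

# exact (flat) search: score every (query, corpus) pair by inner product, keep the k best
scores = queries @ corpus.T  # (num_queries, num_corpus)
topk_indices = np.argsort(-scores, axis=1)[:, :topk]
topk_scores = np.take_along_axis(scores, topk_indices, axis=1)

# on unit vectors, the inner product equals the cosine similarity of the raw vectors
cosine = (queries_raw @ corpus_raw.T) / (
    np.linalg.norm(queries_raw, axis=1, keepdims=True) * np.linalg.norm(corpus_raw, axis=1)
)
print(np.allclose(scores, cosine, atol=1e-5))  # True
print(topk_indices.shape, topk_scores.shape)    # (5, 3) (5, 3)
```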
In [12]:
# merge the LoRA weights into the base model and remove the peft wrappers for inference
sentence_embedding_module.sentence_embedding_model.model = sentence_embedding_module.sentence_embedding_model.model.merge_and_unload()
sentence_embedding_module
Out[12]:
SentenceEmbeddingLightningModule(
  (sentence_embedding_model): SentenceEmbeddingModel(
    (model): BloomModel(
      (word_embeddings): Embedding(250880, 2048)
      (word_embeddings_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (h): ModuleList(
        (0-23): 24 x BloomBlock(
          (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (self_attention): BloomAttention(
            (query_key_value): Linear(in_features=2048, out_features=6144, bias=True)
            (dense): Linear(in_features=2048, out_features=2048, bias=True)
            (attention_dropout): Dropout(p=0.0, inplace=False)
          )
          (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (mlp): BloomMLP(
            (dense_h_to_4h): Linear(in_features=2048, out_features=8192, bias=True)
            (gelu_impl): BloomGelu()
            (dense_4h_to_h): Linear(in_features=8192, out_features=2048, bias=True)
          )
        )
      )
      (ln_f): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
    )
    (loss): CLIPLoss()
  )
)
In [13]:
def postprocess_predictions(predictions):
    prediction_outputs = {"ids": [], "embeddings": []}
    for prediction in predictions:
        prediction_outputs["ids"].extend(prediction["ids"])
        embeddings = [embedding for embedding in prediction["embeddings"].cpu().numpy()]
        prediction_outputs["embeddings"].extend(embeddings)
    
    return pd.DataFrame(prediction_outputs)
In [14]:
df_validation = dataset_dict["validation"].to_pandas()
df_validation
Out[14]:
query product_title product_id esci_label split relevance_label __index_level_0__
0 !qscreen fence without holes Zippity Outdoor Products ZP19026 Lightweight P... B07DHX8YH2 E test 1 34
1 !qscreen fence without holes ColourTree 4' x 50' Green Fence Privacy Screen... B07DS1YCRZ S test 1 35
2 !qscreen fence without holes ColourTree 6' x 50' Black Fence Privacy Screen... B07DS3J3MB S test 1 36
3 !qscreen fence without holes Sunnyglade 6 feet x 50 feet Privacy Screen Fen... B07MFP4PPQ E test 1 39
4 !qscreen fence without holes Amgo 6' x 50' Black Fence Privacy Screen Winds... B07R3TNQDM E test 1 41
... ... ... ... ... ... ... ...
286537 香奈儿 Chânél Chance Eau Tendre Eau de Toilette Women... B08181G6MP E test 1 2614586
286538 香奈儿 Steve Madden Designer 15 Inch Carry on Suitcas... B01HFH4DAI E test 1 2614587
286539 香奈儿 CHANEL Le Lift Creme Yeux, Black, 0.5 Ounce B00NIQGQAQ E test 1 2614588
286540 香奈儿 Chanel Bleu de Chanel Eau de Parfum Spray for ... B00NAGVL7W E test 1 2614589
286541 香奈儿 Chânél No. 5 by Chânél Eau De Parfum Premiere ... B081X6DRRT E test 1 2614593

286542 rows × 7 columns

In [15]:
df_query = df_validation[["query"]].drop_duplicates().reset_index(drop=True)
dataset_query = datasets.Dataset.from_pandas(df_query)
data_collator = DataCollatorForSentenceEmbedding(
    tokenizer,
    process_tower="tower_1",
    id_field="query"
)
dataloader = DataLoader(
    dataset_query,
    shuffle=False,
    collate_fn=data_collator,
    batch_size=64,
    pin_memory=True,
    num_workers=2
)
predictions = trainer.predict(sentence_embedding_module, dataloader)
df_query = postprocess_predictions(predictions)
df_query
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Out[15]:
ids embeddings
0 !qscreen fence without holes [-0.015719872, 0.0092682745, 0.0032925562, 0.0...
1 #1 black natural hair dye without ammonia or p... [0.007189093, 0.023743032, 0.015393866, 0.0183...
2 #1 rated resveratrol supplement without tea le... [-0.016837992, 0.028490836, 0.004562847, 0.022...
3 #10 envelopes without security tint [-0.021525824, -0.038583543, 0.0010891705, -0....
4 #10 standard no tint no window not self seal [-0.0038119196, -0.0015254373, 0.00934174, -0....
... ... ...
8951 zone mouthguard [0.01643445, -0.015690504, -0.007845199, 0.033...
8952 zoom groom for dogs [-0.051384337, 0.0015823826, 0.0002485972, -0....
8953 zozoville puzzle [0.00043908256, -0.013274448, -0.010934159, -0...
8954 سماعة gaming pc [-0.020317372, -0.004169517, -0.0020055668, 0....
8955 香奈儿 [-0.0122507345, -0.0038623791, 0.0058575706, 0...

8956 rows × 2 columns

In [16]:
df_product = df_validation[["product_id", "product_title"]].drop_duplicates().reset_index(drop=True)
dataset_product = datasets.Dataset.from_pandas(df_product)
data_collator = DataCollatorForSentenceEmbedding(
    tokenizer,
    process_tower="tower_2",
    id_field="product_id"
)
dataloader = DataLoader(
    dataset_product,
    shuffle=False,
    collate_fn=data_collator,
    batch_size=64,
    pin_memory=True,
    num_workers=2
)
predictions = trainer.predict(sentence_embedding_module, dataloader)
df_corpus = postprocess_predictions(predictions)
df_corpus
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Out[16]:
ids embeddings
0 B07DHX8YH2 [-0.0298094, 0.0066097863, -0.007526012, -0.00...
1 B07DS1YCRZ [-0.00046652087, -0.012415803, 0.008917227, -0...
2 B07DS3J3MB [-0.0028801567, -0.009671217, 0.011117947, -0....
3 B07MFP4PPQ [-0.02145726, -0.016728751, 0.0012048857, -0.0...
4 B07R3TNQDM [-0.020977724, -0.01857435, 0.010523445, -0.02...
... ... ...
132773 B01E7KBXWC [-0.016616581, -0.004221971, 0.009728117, -0.0...
132774 B08181G6MP [-0.016768515, 0.018336102, 0.010495669, -0.01...
132775 B01HFH4DAI [0.01899822, 0.010939811, -0.0010054795, -0.01...
132776 B00NIQGQAQ [-0.053966485, 0.004088956, 0.00047079526, 0.0...
132777 B081X6DRRT [-0.020010458, 0.0047987755, 0.003917635, -0.0...

132778 rows × 2 columns

In [17]:
index_ids = df_corpus["ids"].tolist()
index_embeddings = np.vstack(df_corpus["embeddings"]).astype(np.float32)
query_embeddings = np.vstack(df_query["embeddings"]).astype(np.float32)
dim = index_embeddings.shape[1]
In [18]:
topk = 10
knn_index = faiss.IndexFlatIP(dim)
knn_index = faiss.index_cpu_to_all_gpus(knn_index)
knn_index.add(index_embeddings)
knn_scores, knn_indices = knn_index.search(query_embeddings, k=topk)
knn_indices
Out[18]:
array([[129760,     10, 126665, ..., 129761, 130345,     11],
       [    22,     43,     36, ...,  31539,     40,     18],
       [    61,     79,     51, ...,     48,     76,     59],
       ...,
       [132754,   9639, 132755, ..., 132757,  98596, 106510],
       [ 92659,  54626,  76087, ...,  14052,  14060,  61385],
       [132772, 113343,  93151, ..., 113356, 132052, 132777]])
In [19]:
# convert knn retrieval result to {query -> [list of knn retrieved corpus id]}
knn_dict = {}
for query, knn_indices_per_row in zip(df_query["ids"], knn_indices):
    corpus_indices_per_row = [index_ids[index] for index in knn_indices_per_row]
    knn_dict[query] = corpus_indices_per_row
In [20]:
# convert validation dataset to {query -> [list of ground truth corpus id]}
eval_dict = {}
for query, product_id in zip(df_validation["query"], df_validation["product_id"]):
    if query not in eval_dict:
        eval_dict[query] = [product_id]
    else:
        eval_dict[query].append(product_id)
In [21]:
def compute_metrics(
    knn_dict,
    eval_dict,
    top_k: int
):
    recalls = []
    for query, knn_results in knn_dict.items():
        knn_set = set(knn_results)
        eval_set = set(eval_dict[query])
        numerator = len(knn_set.intersection(eval_set))
        denominator = min(len(eval_set), top_k)
        recall = numerator / denominator
        recalls.append(recall)

    avg_recall = np.mean(recalls)
    return avg_recall
In [22]:
compute_metrics(knn_dict, eval_dict, topk)
Out[22]:
0.3928341911425877
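To make the metric concrete: for each query, `compute_metrics` divides the number of relevant products found in the top-k by min(number of relevant products, k), so a query with more than k relevant products is not penalized for the ones that cannot fit. The toy example below uses hypothetical query and product ids with k = 3, re-declaring the same function so it runs on its own.

```python
import numpy as np


def compute_metrics(knn_dict, eval_dict, top_k: int):
    # same logic as the cell above
    recalls = []
    for query, knn_results in knn_dict.items():
        knn_set = set(knn_results)
        eval_set = set(eval_dict[query])
        numerator = len(knn_set.intersection(eval_set))
        denominator = min(len(eval_set), top_k)
        recalls.append(numerator / denominator)
    return np.mean(recalls)


top_k = 3
knn_dict = {
    "query_a": ["p1", "p2", "p3"],  # 2 of 2 relevant retrieved -> 2 / min(2, 3) = 1.0
    "query_b": ["p4", "p5", "p6"],  # 1 of 5 relevant retrieved -> 1 / min(5, 3) = 1/3
}
eval_dict = {
    "query_a": ["p1", "p3"],
    "query_b": ["p6", "p7", "p8", "p9", "p10"],
}
recall = compute_metrics(knn_dict, eval_dict, top_k)
print(recall)  # (1.0 + 1/3) / 2, roughly 0.667
```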

We conclude our article with some guidance on training with LoRA and on decoder-based retrieval models.

LoRA:

  • The most critical LoRA hyperparameter is how many LoRA adapters are used in total: applying LoRA to all linear layers in the transformer block appears necessary to match full fine tuning's performance. Other parameters, such as the rank $r$, don't affect performance much, i.e. it's preferable to adapt more weight matrices with a small rank than to adapt a single type of weight matrix with a larger rank.
  • When training with LoRA, a lower learning rate and more training steps might be required to match full fine tuning's performance.
  • The effectiveness of LoRA may be task dependent. Compared to full fine tuning, LoRA might stumble on more challenging tasks such as mathematical reasoning [5].

Personally, LoRA feels very much akin to the matrix factorization and factorization machine family of methods, with a twist.

Decoder retrieval models:

  • Using LLMs for embeddings has garnered considerable interest, with good reason. For example, Improving Text Embeddings with Large Language Models [11] showed that fine tuning an LLM backbone (Mistral-7B) on synthetic data along with a moderate amount of labeled text pairs is sufficient to obtain high quality embeddings, foregoing the need for large amounts of weakly supervised text pairs.
  • Keep in mind that apart from retrieval quality, there's also the cost of operating these large LLMs for embedding use cases, both in inference speed and in storage: billion-parameter LLMs typically produce embeddings with a larger hidden dimension (2048+) [6].
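A back-of-the-envelope storage estimate for this notebook's own corpus: 132,778 unique validation products (from the corpus table above) stored as float32 vectors. The 2048 dimension is bloomz-1b7's hidden size; the 768 comparison point is the hidden size of a BERT-base style encoder, used here only as an assumed reference.

```python
num_products = 132_778
bytes_fp32 = 4

sizes_gb = {}
for name, dim in [("bloomz-1b7 (2048-d)", 2048), ("BERT-base style (768-d)", 768)]:
    # raw float32 storage for one vector per product, before any index overhead
    sizes_gb[name] = num_products * dim * bytes_fp32 / 1e9
    print(f"{name}: {sizes_gb[name]:.2f} GB")
```

The ratio simply follows the embedding dimension (2048 / 768, roughly 2.7x), and it grows linearly with corpus size, which matters far more for production catalogs than for this validation split.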

Reference

  • [1] PEFT Documentation: LoRA
  • [2] PEFT Documentation: LoRA for semantic similarity tasks
  • [3] Shopping Queries Dataset: A Large-Scale ESCI Benchmark for Improving Product Search
  • [4] LoRA From Scratch – Implement Low-Rank Adaptation for LLMs in PyTorch
  • [5] Fine-Tuning LLMs: LoRA or Full-Parameter? An in-depth Analysis with Llama 2
  • [6] OpenAI GPT-3 Text Embeddings - Really a new state-of-the-art in dense text embeddings?
  • [7] Xueguang Ma, Liang Wang, Nan Yang, Furu Wei, Jimmy Lin - Fine-Tuning LLaMA for Multi-Stage Text Retrieval (2023)
  • [8] Chandan K. Reddy, Lluís Màrquez, Fran Valero, Nikhil Rao, Hugo Zaragoza, Sambaran Bandyopadhyay, Arnab Biswas, Anlu Xing, Karthik Subbian - Shopping Queries Dataset: A Large-Scale ESCI Benchmark for Improving Product Search (2022)
  • [9] Edward J. Hu, Yelong Shen, et al. - LoRA: Low-Rank Adaptation of Large Language Models (2021)
  • [10] Arvind Neelakantan, Tao Xu, et al. - Text and Code Embeddings by Contrastive Pre-Training (2022)
  • [11] Liang Wang, Nan Yang, Furu Wei, et al. - Improving Text Embeddings with Large Language Models (2024)