# code for loading the format for the notebook
import os
# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', '..', 'notebook_format'))
from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
os.chdir(path)
# 1. magic for inline plot
# 2. magic to print version
# 3. magic so that the notebook will reload external python modules
# 4. magic to enable retina (high resolution) plots
# https://gist.github.com/minrk/3301035
%matplotlib inline
%load_ext watermark
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'
import os
import torch
import random
import evaluate
import transformers
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from dataclasses import dataclass
from time import perf_counter
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, disable_progress_bar
from transformers import (
AutoConfig,
AutoTokenizer,
AutoModelForSeq2SeqLM,
DataCollatorForSeq2Seq,
Seq2SeqTrainingArguments,
Seq2SeqTrainer,
EarlyStoppingCallback
)
# prevents progress bar and logging from flooding our document
disable_progress_bar()
%watermark -a 'Ethen' -d -u -v -p datasets,evaluate,torch,transformers,numpy,rouge_score,sacrebleu
In this article, we'll be leveraging Huggingface's Transformers library for our machine translation task, as it provides tons of pretrained models that we can either use directly or fine tune on our own tasks.
Machine translation is a sequence to sequence task: our dataset consists of source and target sentence pairs, and in the transformer era we use an encoder-decoder architecture to solve these types of problems. The same architecture can also be used for tasks such as summarization, generative question answering, etc.
We'll be using publicly available mT5 [11] pre-trained checkpoints, which is essentially a multi-lingual version of T5 (Text To Text Transfer Transformer) [10]. Quick recap: T5 is also a language model pre-trained on un-labeled data; its main distinction is that it re-formulates all text based NLP problems, be it classification (e.g. GLUE or SuperGLUE benchmarks), translation, summarization, or question answering, into a sequence to sequence setting that can be solved by an encoder-decoder architecture, as shown in its original diagram.
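To make this text-to-text framing concrete, below are a few illustrative input/output pairs modeled after the figure in the original T5 paper (the task prefixes are a T5 convention; mT5 is pre-trained purely on unlabeled multi-lingual text, so no prefix is strictly required when we fine tune it later):
# illustrative only: every task becomes "text in, text out", distinguished by a task prefix
# (examples modeled after the original T5 paper's figure, not taken from our dataset)
text_to_text_examples = [
    # translation
    ("translate English to German: That is good.", "Das ist gut."),
    # linguistic acceptability (CoLA)
    ("cola sentence: The course is jumping well.", "not acceptable"),
    # summarization
    ("summarize: state authorities dispatched emergency crews tuesday to survey the damage ...",
     "six people hospitalized after a storm in attala county"),
]
for source_text, target_text in text_to_text_examples:
    print(f"input : {source_text}")
    print(f"output: {target_text}")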
Some key takeaways:
Apart from inheriting many of the properties from T5, some additional key results from mT5:
@dataclass
class Config:
cache_dir: str = "./translation"
data_dir: str = os.path.join(cache_dir, "wmt16")
source_lang: str = "de"
target_lang: str = "en"
batch_size: int = 16
num_workers: int = 4
seed: int = 42
max_source_length: int = 128
max_target_length: int = 128
lr: float = 0.0005
weight_decay: float = 0.01
epochs: int = 20
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_checkpoint: str = "google/mt5-small"
def __post_init__(self):
random.seed(self.seed)
np.random.seed(self.seed)
torch.manual_seed(self.seed)
torch.cuda.manual_seed_all(self.seed)
config = Config()
We'll be using the Multi30k dataset [9], used in the WMT16 multimodal machine translation shared task, to demonstrate a transformer model on a machine translation task. This is a moderately sized German to English translation dataset of around 29K sentence pairs, so we can get results without waiting too long. We'll start off by downloading the raw dataset and extracting it. Feel free to swap this step with any other machine translation dataset.
Utility scripts for creating this data reside here. If the original link for these datasets fails to load, use the alternative Google Drive link.
from translation_utils import download_file
# files are downloaded from
# http://www.statmt.org/wmt16/multimodal-task.html
urls = [
'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',
'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',
'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/mmt16_task1_test.tar.gz'
]
for url in urls:
download_file(url, config.data_dir)
!ls $config.data_dir
The original dataset splits source and target language into two separate files, e.g. train.de and train.en are our training files for German and English. This format is useful when we wish to train a tokenizer on the source or target language independently. On the other hand, having each source and target pair together in a single file makes it easier to load them in batches for training or evaluating our machine translation model. We'll create a paired dataset and load it with the datasets library.
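The actual implementation lives in translation_utils; as a rough idea, a pairing helper along these lines would do the trick (a hypothetical sketch, the real script may differ, e.g. in how it handles encodings or malformed lines):
# hypothetical sketch of the pairing step: zip the line-aligned source/target files
# into a single tab separated file, one (source, target) pair per line
def create_translation_data_sketch(source_input_path, target_input_path, output_path):
    with open(source_input_path, encoding="utf-8") as f_source, \
         open(target_input_path, encoding="utf-8") as f_target, \
         open(output_path, "w", encoding="utf-8") as f_out:
        for source_line, target_line in zip(f_source, f_target):
            f_out.write(f"{source_line.strip()}\t{target_line.strip()}\n")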
from translation_utils import create_translation_data
data_files = {}
for split in ["train", "val", "test"]:
source_input_path = os.path.join(config.data_dir, f"{split}.{config.source_lang}")
target_input_path = os.path.join(config.data_dir, f"{split}.{config.target_lang}")
output_path = os.path.join(config.cache_dir, f"{split}.tsv")
create_translation_data(source_input_path, target_input_path, output_path)
data_files[split] = [output_path]
data_files
dataset_dict = load_dataset(
"csv",
delimiter="\t",
column_names=[config.source_lang, config.target_lang],
data_files=data_files
)
dataset_dict
# We can access a split, and each record/pair, with the following syntax
sample = dataset_dict["train"][0]
sample
Before we start building our model, we'll first go over how to quantitatively evaluate one. Evaluating a generative model's machine translation output is not as black and white as, say, an output for a classification task: given a source sentence, there might be multiple equally good target sentences. Popular automated metrics belong to the ROUGE and BLEU family, both of which measure correspondence between a machine generated output and a human reference via word/token overlap.
rouge_score = evaluate.load("rouge", cache_dir=config.cache_dir)
bleu_score = evaluate.load("bleu", cache_dir=config.cache_dir)
sacrebleu_score = evaluate.load("sacrebleu", cache_dir=config.cache_dir)
BLEU (Bilingual Evaluation Understudy) [6] works by comparing n-grams of the machine generated translation with n-grams of a human provided reference translation and counting the number of matches between the two: the more matches, the merrier. Its first component boils down to a modified precision metric:
\begin{align} \text{Modified Precision}=\frac{\text{capped number of overlapping words}}{\text{total number of words in generated summary}} \end{align}
where in the numerator, we will give each word credit only up to the maximum number of times that word appeared in its reference sentence.
e.g. Given the following example reference and candidate
Candidate: the the the cat mat
Reference: the cat is on the mat
We would have a modified precision of 4 / 5. There are 5 candidate words in total, making up the denominator. As for the numerator, "the" appears 3 times in the candidate sentence but is capped at 2, since "the" only appears 2 times in the reference sentence, while "cat" and "mat" each appear once, giving 2 + 1 + 1 = 4.
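We can sanity check this capping logic with a few lines of standalone code:
# verify the modified unigram precision for the toy example above
from collections import Counter
candidate = "the the the cat mat".split()
reference = "the cat is on the mat".split()
candidate_counts = Counter(candidate)
reference_counts = Counter(reference)
# each candidate word only gets credit up to its count in the reference
overlap = sum(min(count, reference_counts[word]) for word, count in candidate_counts.items())
print(overlap / len(candidate))  # 4 / 5 = 0.8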
BLEU introduces two additional adjustments on top of its modified precision cornerstone.
e.g. If we have a candidate, reference pair like below:
Candidate: of the
Reference: It is the practical guide for the army always to heed the directions of the party.
We would obtain a modified unigram precision of 2/2. Intuitively, this candidate translation's precision metric is a bit inflated due to its short nature. To account for this, a sentence brevity penalty is introduced as the second component:
\begin{align} \text{BP} = \begin{cases} 1 & \text{if } c > r \\ \exp \big(1-\frac{r}{c}\big) & \text{if } c \leq r \end{cases} \end{align}
where $c$ is the word length of the candidate sentence, and $r$ is the best match length for each candidate sentence in the corpus. e.g. Given 2 references with lengths of 12 and 15 words and a candidate translation of 12 words, the brevity penalty will be 1, as the candidate's length matches one of the reference translations exactly; the closest reference's sentence length is termed the "best match length". Note that this brevity penalty is computed by summing over the entire corpus, to allow some freedom at the sentence level.
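A quick standalone sketch of this penalty:
import math
def brevity_penalty(candidate_length, closest_reference_length):
    # no penalty when the candidate is longer than its closest reference
    if candidate_length > closest_reference_length:
        return 1.0
    return math.exp(1 - closest_reference_length / candidate_length)
# references of length 12 and 15, candidate of length 12 -> best match length is 12
print(brevity_penalty(12, 12))  # 1.0
# a shorter 10 word candidate against the same references gets penalized
print(brevity_penalty(10, 12))  # ~0.82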
Finally, putting it all together, BLEU score is defined as:
\begin{align} \text{BLEU} = \text{BP} \cdot \exp \bigg( \sum_{n=1}^{N} \frac{1}{N} \log p_n \bigg) \end{align}
where $p_n$ is the modified precision for $n$-grams, and the exponential term on the right is the geometric mean of these $n$-gram modified precisions, with $N$ typically set to 4 (i.e. up to 4-grams). The original paper is a great read for a more in depth explanation.
generated_summary = "I absolutely loved reading the Hunger Games"
reference_summary = "I loved reading the Hunger Games"
# as we can see, the metric returns the necessary components, including
# 'bleu': the bleu score
# 'precisions': the individual n-gram precisions (1-gram up to 4-gram)
# 'brevity_penalty': the brevity penalty
# we can confirm precision 1 is indeed 6 / 7
bleu_score.compute(
predictions=[generated_summary],
references=[reference_summary]
)
# result from another bleu implementation
sacrebleu_score.compute(
predictions=[generated_summary],
references=[reference_summary]
)
ROUGE (Recall-Oriented Understudy for Gisting Evaluation) [8] is also based on counting overlapping tokens between our system's generated summary and a reference summary (the ground truth, typically written by humans). In its original introduction, ROUGE was more focused on the recall side of the picture, compared to BLEU, which is more precision oriented. Nowadays, we'll commonly see it reported as the F1 score of the overlap, where:
\begin{align} \text{Recall}=\frac{\text{number of overlapping words}}{\text{total number of words in reference summary}} \end{align}
\begin{align} \text{Precision}=\frac{\text{number of overlapping words}}{\text{total number of words in generated summary}} \end{align}
There are different variants of rouge score, the most popular ones being:
rouge-{n}: computes the rouge score for matching n-grams. The most common ones include rouge-1 for unigram overlap and rouge-2 for bigram overlap.
rougeL: L here stands for longest common subsequence, i.e. the longest sequence of words that are shared between both and appear in the same order, though not necessarily consecutively. Without depending on consecutive n-grams, this variant aims to capture sentence structure.
# rouge 1:
# precision = 6 / 7
# recall = 6 / 6
# f1 = 2 * (precision * recall) / (recall + precision)
scores = rouge_score.compute(
predictions=[generated_summary],
references=[reference_summary],
rouge_types=["rouge1", "rouge2", "rougeL"]
)
scores
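To make the longest common subsequence behind rougeL more concrete, here's a minimal standalone sketch of the rougeL f1 computation on our toy pair (the rouge_score package applies its own tokenization, which this sketch glosses over with a simple lowercase split):
# standalone sketch of rougeL: f1 score over the longest common subsequence
def lcs_length(a, b):
    # classic dynamic programming solution for longest common subsequence length
    dp = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, a_token in enumerate(a, 1):
        for j, b_token in enumerate(b, 1):
            dp[i][j] = dp[i - 1][j - 1] + 1 if a_token == b_token else max(dp[i - 1][j], dp[i][j - 1])
    return dp[-1][-1]
prediction_tokens = generated_summary.lower().split()
reference_tokens = reference_summary.lower().split()
lcs = lcs_length(prediction_tokens, reference_tokens)
precision = lcs / len(prediction_tokens)
recall = lcs / len(reference_tokens)
# should agree with the rougeL value computed above for this simple pair
print(2 * precision * recall / (precision + recall))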
Important things to keep in mind:
As usual, we'll load our tokenizer and tokenize our raw text, which includes both our source and target sentences.
Feel free to try out different model checkpoints. We pick a small checkpoint for rapid experimentation purposes, as well as a multi-lingual one, such as mT5, to cope with our multi-lingual dataset.
tokenizer = AutoTokenizer.from_pretrained(config.model_checkpoint, cache_dir=config.cache_dir)
model_name = config.model_checkpoint.split("/")[-1]
fine_tuned_model_checkpoint = os.path.join(
config.cache_dir,
f"{model_name}_{config.source_lang}-{config.target_lang}",
"checkpoint-4500"
)
if os.path.isdir(fine_tuned_model_checkpoint):
do_train = False
model = AutoModelForSeq2SeqLM.from_pretrained(fine_tuned_model_checkpoint, cache_dir=config.cache_dir)
else:
do_train = True
model = AutoModelForSeq2SeqLM.from_pretrained(config.model_checkpoint, cache_dir=config.cache_dir)
print("number of parameters:", model.num_parameters())
def batch_tokenize_fn(examples):
"""
Generate the input_ids and labels field for huggingface dataset/dataset dict.
Truncation is enabled where we cap the sentence to the max length. Padding will be done later
in a data collator, so we pad examples to the longest length within a mini-batch and not
the whole dataset.
"""
sources = examples[config.source_lang]
targets = examples[config.target_lang]
model_inputs = tokenizer(sources, max_length=config.max_source_length, truncation=True)
# setup the tokenizer for targets,
# huggingface expects the target tokenized ids to be stored in the labels field
# note, newer version of tokenizer supports a text_target argument, where we can create
# source and target sentences in one go
with tokenizer.as_target_tokenizer():
labels = tokenizer(targets, max_length=config.max_target_length, truncation=True)
model_inputs["labels"] = labels["input_ids"]
return model_inputs
dataset_dict_tokenized = dataset_dict.map(
batch_tokenize_fn,
batched=True,
remove_columns=dataset_dict["train"].column_names
)
dataset_dict_tokenized
The data collator used for seq2seq models, DataCollatorForSeq2Seq, needs to pad both inputs and labels. One thing to note is that input ids need to be padded with the tokenizer's padding token, whereas labels need to be padded with a sentinel value, -100, which is the default value PyTorch uses for ignoring indices during loss computation.
Apart from that, there's another field specific to encoder-decoder models called decoder_input_ids. This field contains the input ids that will be fed to the decoder, and most of the time they are shifted versions of the labels with a special token at the beginning. This shifting serves two main purposes: 1. it ensures the decoder only sees the previous ground truth labels during training, not the current or future ones; 2. it implements teacher forcing, where during training our decoder always receives the ground truth token for the next step, no matter what the model's predictions are.
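To make the shifting concrete, here's an illustrative sketch for a single toy label sequence. The data collator typically delegates this to the model's prepare_decoder_input_ids_from_labels method, which for t5/mt5 shifts the labels right and prepends the decoder start token (the pad token):
# illustrative sketch: decoder_input_ids are the labels shifted one position to the right,
# with the decoder start token prepended and any -100 sentinel mapped back to the pad token id
labels = torch.tensor([[259, 978, 123, 1]])  # made-up token ids ending in </s>
decoder_input_ids = labels.new_zeros(labels.shape)
decoder_input_ids[:, 1:] = labels[:, :-1].clone()
decoder_input_ids[:, 0] = model.config.decoder_start_token_id
decoder_input_ids.masked_fill_(decoder_input_ids == -100, tokenizer.pad_token_id)
print(decoder_input_ids)
# the collator obtains the same result through the model itself
print(model.prepare_decoder_input_ids_from_labels(labels=labels))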
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
features = [dataset_dict_tokenized["train"][i] for i in range(2)]
output = data_collator(features)
output
For seq2seq models, we'll use the seq2seq variants of TrainingArguments and Trainer. One of the most important differences is setting predict_with_generate=True. The decoder performs inference by predicting tokens one by one, which is implemented by the model's generate method. Setting predict_with_generate=True tells Seq2SeqTrainer to use that method for evaluation.
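As a mental model of what generate does under the hood, here's a minimal greedy decoding loop. The real generate method caches encoder outputs and past key/values, and supports beam search, sampling and various stopping criteria, all of which this sketch omits:
# minimal greedy decoding sketch: repeatedly run the model, pick the most likely next token,
# and append it to the decoder input until </s> is produced or we hit the token budget
@torch.no_grad()
def greedy_decode(model, tokenizer, text, max_new_tokens=20):
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]], device=model.device)
    for _ in range(max_new_tokens):
        logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
        next_token_id = logits[:, -1].argmax(dim=-1, keepdim=True)
        decoder_input_ids = torch.cat([decoder_input_ids, next_token_id], dim=-1)
        if next_token_id.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)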
model_name = config.model_checkpoint.split("/")[-1]
output_dir = os.path.join(config.cache_dir, f"{model_name}_{config.source_lang}-{config.target_lang}")
args = Seq2SeqTrainingArguments(
output_dir=output_dir,
evaluation_strategy="steps",
learning_rate=config.lr,
per_device_train_batch_size=config.batch_size,
per_device_eval_batch_size=config.batch_size,
weight_decay=config.weight_decay,
save_total_limit=2,
num_train_epochs=config.epochs,
predict_with_generate=True,
load_best_model_at_end=True,
greater_is_better=True,
metric_for_best_model="rougeL",
gradient_accumulation_steps=8,
do_train=do_train,
# careful when attempting to train t5 models on fp16 mixed precision,
# the model was trained on bfloat16 mixed precision, and mixing different mixed precision
# type might result in nan loss
# https://discuss.huggingface.co/t/mixed-precision-for-bfloat16-pretrained-models/5315
fp16=False
)
def compute_metrics(eval_pred):
"""
Compute rouge and bleu metrics for seq2seq model generated prediction.
tip: we can run trainer.predict on our eval/test dataset to see what a sample
eval_pred object would look like when implementing custom compute metrics function
"""
predictions, labels = eval_pred
# Decode generated summaries, which is in ids into text
decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
# Replace -100 in the labels as we can't decode them
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
# Decode labels, a.k.a. reference summaries into text
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
result = rouge_score.compute(
predictions=decoded_preds,
references=decoded_labels,
rouge_types=["rouge1", "rouge2", "rougeL"]
)
score = sacrebleu_score.compute(
predictions=decoded_preds,
references=decoded_labels
)
result["sacrebleu"] = score["score"]
return {k: round(v, 4) for k, v in result.items()}
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
trainer = Seq2SeqTrainer(
model,
args,
train_dataset=dataset_dict_tokenized["train"],
eval_dataset=dataset_dict_tokenized["val"],
data_collator=data_collator,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
callbacks=[EarlyStoppingCallback()]
)
# should take around 4117.78 seconds on a single V100 GPU
if trainer.args.do_train:
os.environ["DISABLE_MLFLOW_INTEGRATION"] = "TRUE"
t1_start = perf_counter()
train_output = trainer.train()
t1_stop = perf_counter()
print("Training elapsed time:", t1_stop - t1_start)
# saving the model which allows us to leverage
# .from_pretrained(model_path)
trainer.save_model(fine_tuned_model_checkpoint)
trainer.evaluate()
These numbers, obtained by fine-tuning a pre-trained mT5 model, are pretty solid when compared with MarianMT, an already pretrained machine translation model, Helsinki-NLP/opus-mt-de-en. Note that the MarianMT model only has 74,410,496 parameters, which is a lot smaller than mt5-small's 300,176,768 parameters.
{'eval_loss': 1.1351912021636963, 'eval_rouge1': 0.7198, 'eval_rouge2': 0.4904, 'eval_rougeL': 0.6935, 'eval_sacrebleu': 40.2274, 'eval_runtime': 40.1401, 'eval_samples_per_second': 25.262, 'eval_steps_per_second': 1.594}
These numbers are computed using this script.
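If we want to double check these parameter counts ourselves (this downloads the MarianMT checkpoint):
# compare parameter counts between the MarianMT baseline and our mt5-small model
marian_model = AutoModelForSeq2SeqLM.from_pretrained(
    "Helsinki-NLP/opus-mt-de-en", cache_dir=config.cache_dir
)
print("marian parameters   :", marian_model.num_parameters())
print("mt5-small parameters:", model.num_parameters())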
Apart from quantitative metric evaluation, we can also look at sample generated translations.
def generate_translation(model, tokenizer, example):
"""print out the source, target and predicted raw text."""
source = example[config.source_lang]
target = example[config.target_lang]
input_ids = tokenizer(source)["input_ids"]
input_ids = torch.LongTensor(input_ids).view(1, -1).to(model.device)
generated_ids = model.generate(input_ids, max_new_tokens=20)
prediction = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print('source: ', source)
print('target: ', target)
print('prediction: ', prediction)
example = dataset_dict['val'][0]
generate_translation(model, tokenizer, example)