# code for loading the format for the notebook
import os
# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))
from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
os.chdir(path)
# 1. magic for inline plot
# 2. magic to print version
# 3. magic so that the notebook will reload external python modules
# 4. magic to enable retina (high resolution) plots
# https://gist.github.com/minrk/3301035
%matplotlib inline
%load_ext watermark
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'
import os
import math
import time
import torch
import random
import numpy as np
import pandas as pd
from datasets import load_dataset
from torch.utils.data import DataLoader
from tokenizers import ByteLevelBPETokenizer
%watermark -a 'Ethen' -d -t -v -p datasets,numpy,torch,tokenizers,transformers
In this article we'll be leveraging Huggingface's Transformers library for our machine translation task. The library provides thousands of pretrained models that we can use for our own tasks. Apart from that, we'll also take a look at how to use its pre-built tokenizer and model architectures to train a model from scratch.
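As a quick illustration of how little code a pretrained checkpoint requires, here is a minimal sketch using the high level pipeline API with the Helsinki-NLP/opus-mt-de-en checkpoint (the same one we'll load explicitly later). The German sentence is made up purely for demonstration, and we're assuming the generic 'translation' task alias resolves for this single language pair checkpoint.
# minimal sketch of the high level pipeline API; the input sentence is illustrative
# and the generic 'translation' task is assumed to work for this single-pair checkpoint
from transformers import pipeline
translator = pipeline('translation', model='Helsinki-NLP/opus-mt-de-en')
# expected to return a list with a single dict containing a 'translation_text' entry
print(translator('Zwei Männer gehen die Straße entlang.'))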
We'll be using the Multi30k dataset to demonstrate the transformer model on a machine translation task. This German to English training dataset contains around 29K sentence pairs, a moderately sized dataset, so we can get our results without waiting too long. We'll start off by downloading the raw dataset and extracting it. Feel free to swap this step with any other machine translation dataset.
import tarfile
import zipfile
import requests
import subprocess
from tqdm import tqdm
from urllib.parse import urlparse
def download_file(url: str, directory: str):
"""
Download the file at ``url`` to ``directory``.
Extract the file content to ``directory`` if the original file
is a zip, tar.gz or tgz file.
Parameters
----------
url : str
url of the file.
directory : str
Directory to download the file.
"""
response = requests.get(url, stream=True)
response.raise_for_status()
content_len = response.headers.get('Content-Length')
total = int(content_len) if content_len is not None else 0
os.makedirs(directory, exist_ok=True)
file_name = get_file_name_from_url(url)
file_path = os.path.join(directory, file_name)
with tqdm(unit='B', total=total) as pbar, open(file_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
pbar.update(len(chunk))
f.write(chunk)
extract_compressed_file(file_path, directory)
def extract_compressed_file(compressed_file_path: str, directory: str):
"""
Extract a compressed file to ``directory``. Supports zip, tar.gz and
tgz extensions.
Parameters
----------
compressed_file_path : str
directory : str
File will be extracted to this directory.
"""
basename = os.path.basename(compressed_file_path)
if 'zip' in basename:
with zipfile.ZipFile(compressed_file_path, "r") as zip_f:
zip_f.extractall(directory)
elif 'tar.gz' in basename or 'tgz' in basename:
with tarfile.open(compressed_file_path) as f:
f.extractall(directory)
def get_file_name_from_url(url: str) -> str:
"""
Return the file_name from a URL
Parameters
----------
url : str
URL to extract file_name from
Returns
-------
file_name : str
"""
parse = urlparse(url)
return os.path.basename(parse.path)
urls = [
'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',
'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',
'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/mmt16_task1_test.tar.gz'
]
directory = 'multi30k'
for url in urls:
download_file(url, directory)
We print out the contents of the data directory and some sample data.
!ls multi30k
!head multi30k/train.de
!head multi30k/train.en
The original dataset splits the source and the target language into two separate files (e.g. train.de and train.en are the training data for German and English respectively). This type of format is useful when we wish to train a tokenizer on top of the source or target language, as we'll soon see.
On the other hand, having the source and target pair together in a single file makes it easier to load them in batches for training or evaluating our machine translation model. We'll create the paired dataset and then load it. For this step, it will be helpful to have some basic understanding of Huggingface's datasets library.
def create_translation_data(
source_input_path: str,
target_input_path: str,
output_path: str,
delimiter: str = '\t',
encoding: str = 'utf-8'
):
"""
Creates the paired source and target dataset from the separated ones.
e.g. creates `train.tsv` from `train.de` and `train.en`
"""
with open(source_input_path, encoding=encoding) as f_source_in, \
open(target_input_path, encoding=encoding) as f_target_in, \
open(output_path, 'w', encoding=encoding) as f_out:
for source_raw in f_source_in:
source_raw = source_raw.strip()
target_raw = f_target_in.readline().strip()
if source_raw and target_raw:
output_line = source_raw + delimiter + target_raw + '\n'
f_out.write(output_line)
source_lang = 'de'
target_lang = 'en'
data_files = {}
for split in ['train', 'val', 'test']:
source_input_path = os.path.join(directory, f'{split}.{source_lang}')
target_input_path = os.path.join(directory, f'{split}.{target_lang}')
output_path = f'{split}.tsv'
create_translation_data(source_input_path, target_input_path, output_path)
data_files[split] = [output_path]
data_files
dataset_dict = load_dataset(
'csv',
delimiter='\t',
column_names=[source_lang, target_lang],
data_files=data_files
)
dataset_dict
We can access each split, and each record/pair within it, with the following syntax.
dataset_dict['train'][0]
To get started, we'll use the MarianMT pretrained translation model.
The first thing we'll do is load the pre-trained tokenizer, using the from_pretrained
syntax. This ensures we get the tokenizer and vocabulary corresponding to the model architecture for this specific checkpoint.
from transformers import AutoTokenizer
model_checkpoint = "Helsinki-NLP/opus-mt-de-en"
pretrained_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
pretrained_tokenizer
We can pass a single record, or a list of records, to huggingface's tokenizer. Then depending on the model, we might see different keys in the dictionary returned. For example, here we have:
input_ids: The tokenizer converted our raw input text into numerical ids.
attention_mask: Mask to avoid performing attention on padded token ids. As we haven't yet performed the padding step, the numbers are all showing up as 1, indicating they are not masked.
pretrained_tokenizer(dataset_dict['train']['de'][0])
# notice the last token id is 0, the end of sentence special token
pretrained_tokenizer.convert_ids_to_tokens(0)
pretrained_tokenizer(dataset_dict['train']['de'][0:2])
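Since the attention_mask only becomes interesting once padding is involved, here is a quick sketch that passes the same two records with padding enabled (assuming the tokenizer's standard padding=True argument). The shorter sentence should now be padded, and its extra positions masked out with 0.
# sketch: enable padding so the attention_mask shows 0 for the padded positions
pretrained_tokenizer(dataset_dict['train']['de'][0:2], padding=True)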
We can apply the tokenizer to our entire raw dataset, so this preprocessing will be a one-time process. By passing the function to our dataset dict's map
method, it will apply the same tokenizing step to all the splits in our data.
max_source_length = 128
max_target_length = 128
source_lang = "de"
target_lang = "en"
def batch_tokenize_fn(examples):
"""
Generate the input_ids and labels fields for huggingface's dataset/dataset dict.
Truncation is enabled, so we cap each sentence at the max length; padding will be done later
by a data collator, which pads examples to the longest length in the batch rather than the whole dataset.
"""
sources = examples[source_lang]
targets = examples[target_lang]
model_inputs = pretrained_tokenizer(sources, max_length=max_source_length, truncation=True)
# setup the tokenizer for targets,
# huggingface expects the target tokenized ids to be stored in the labels field
with pretrained_tokenizer.as_target_tokenizer():
labels = pretrained_tokenizer(targets, max_length=max_target_length, truncation=True)
model_inputs["labels"] = labels["input_ids"]
return model_inputs
tokenized_dataset_dict = dataset_dict.map(batch_tokenize_fn, batched=True, num_proc=8)
tokenized_dataset_dict
# printing out the tokenized data, to check for the newly added fields
tokenized_dataset_dict['train'][0]
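As an optional sanity check (a sketch, not part of the original walkthrough), we can decode the newly added fields back into text: input_ids should closely reproduce the German source and labels the English target.
# sketch: decode the tokenized fields back to text to verify the mapping
example = tokenized_dataset_dict['train'][0]
print(pretrained_tokenizer.decode(example['input_ids'], skip_special_tokens=True))
print(pretrained_tokenizer.decode(example['labels'], skip_special_tokens=True))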
Having prepared our dataset, we'll load the pre-trained model. Similar to the tokenizer, we can use the .from_pretrained
method, and specify a valid huggingface model.
from transformers import AutoModelForSeq2SeqLM
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pretrained_model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
print('# of parameters: ', pretrained_model.num_parameters())
pretrained_model
We can directly use this model to generate
the translations, and eyeball the results.
def generate_translation(model, tokenizer, example):
"""print out the source, target and predicted raw text."""
source = example[source_lang]
target = example[target_lang]
input_ids = example['input_ids']
input_ids = torch.LongTensor(input_ids).view(1, -1).to(model.device)
generated_ids = model.generate(input_ids)
prediction = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print('source: ', source)
print('target: ', target)
print('prediction: ', prediction)
example = tokenized_dataset_dict['train'][0]
generate_translation(pretrained_model, pretrained_tokenizer, example)
example = tokenized_dataset_dict['test'][0]
generate_translation(pretrained_model, pretrained_tokenizer, example)
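generate also exposes the usual decoding knobs. The sketch below reuses the test example from the previous cell and turns on beam search; num_beams=5 and the max_length cap are illustrative values, not tuned settings.
# sketch: beam search decoding with illustrative settings, reusing the test example above
input_ids = torch.LongTensor(example['input_ids']).view(1, -1).to(pretrained_model.device)
generated_ids = pretrained_model.generate(input_ids, num_beams=5, max_length=max_target_length)
print(pretrained_tokenizer.decode(generated_ids[0], skip_special_tokens=True))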
The next section shows the steps for training the model parameters from scratch. Instead of directly instantiating the model using the .from_pretrained
method, we use the .from_config
method, where we specify the configuration for a particular model architecture. The configuration itself will still be created using .from_pretrained
, while overriding some of the configuration hyperparameters, where we opted for a smaller model for faster iteration.
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
DataCollatorForSeq2Seq,
EarlyStoppingCallback,
Seq2SeqTrainingArguments,
Seq2SeqTrainer
)
config_params = {
'd_model': 256,
'decoder_layers': 3,
'decoder_attention_heads': 8,
'decoder_ffn_dim': 512,
'encoder_layers': 6,
'encoder_attention_heads': 8,
'encoder_ffn_dim': 512,
'max_length': 128,
'max_position_embeddings': 128
}
model_checkpoint = "Helsinki-NLP/opus-mt-de-en"
config = AutoConfig.from_pretrained(model_checkpoint, **config_params)
config
model = AutoModelForSeq2SeqLM.from_config(config)
print('# of parameters: ', model.num_parameters())
model
The huggingface library offers pre-built functionality to avoid writing the training logic from scratch. This step can be swapped out with other higher level trainer packages or even by implementing our own training logic. We set up the:
Seq2SeqTrainingArguments: a class that contains all the attributes to customize the training. At the bare minimum, it requires one folder name, which will be used to save model checkpoints.
DataCollatorForSeq2Seq: a helper class provided to batch our examples, and where the padding logic resides.
batch_size = 128
args = Seq2SeqTrainingArguments(
output_dir="test-translation",
evaluation_strategy="epoch",
learning_rate=0.0005,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=0.01,
save_total_limit=3,
num_train_epochs=20,
load_best_model_at_end=True,
predict_with_generate=True,
remove_unused_columns=True,
fp16=True
)
data_collator = DataCollatorForSeq2Seq(pretrained_tokenizer)
callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
trainer = Seq2SeqTrainer(
model,
args,
data_collator=data_collator,
train_dataset=tokenized_dataset_dict["train"],
eval_dataset=tokenized_dataset_dict["val"],
callbacks=callbacks
)
We can take a look at the batched examples. Understanding the output can be beneficial if we wish to customize the data collate function later.
attention_mask: Padded tokens will be masked out with 0.
input_ids: Input ids are padded with the padding special token.
labels: By default -100 will be automatically ignored by PyTorch loss functions, hence we will use that particular id when padding our labels.
dataloader_train = trainer.get_train_dataloader()
batch = next(iter(dataloader_train))
batch
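A few quick sanity checks on the collated batch (a sketch on top of the fields described above): the attention mask should be 0 exactly where input_ids equals the tokenizer's pad token, and padded label positions should carry -100.
# sketch: verify the padding behavior of DataCollatorForSeq2Seq on this batch
print('input_ids shape: ', batch['input_ids'].shape)
print('labels shape: ', batch['labels'].shape)
print('padded input positions: ', (batch['input_ids'] == pretrained_tokenizer.pad_token_id).sum().item())
print('masked positions: ', (batch['attention_mask'] == 0).sum().item())
print('padded label positions: ', (batch['labels'] == -100).sum().item())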
trainer_output = trainer.train()
trainer_output
Similar to what we did before, we can use this model to generate
the translations, and eyeball the results.
def generate_translation(model, tokenizer, example):
"""print out the source, target and predicted raw text."""
source = example[source_lang]
target = example[target_lang]
input_ids = tokenizer(source)['input_ids']
input_ids = torch.LongTensor(input_ids).view(1, -1).to(model.device)
generated_ids = model.generate(input_ids)
prediction = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print('source: ', source)
print('target: ', target)
print('prediction: ', prediction)
example = dataset_dict['train'][0]
generate_translation(model, pretrained_tokenizer, example)
example = dataset_dict['test'][0]
generate_translation(model, pretrained_tokenizer, example)
From our raw pairs, we need to use or train a tokenizer to convert them into numerical indices. Here we'll be training our tokenizer from scratch using Huggingface's tokenizers library. Feel free to swap this step out with other tokenization procedures; what's important is to leave room for special tokens such as the init token that represents the beginning of a sentence, the end of sentence token that represents the end of a sentence, the unknown token, and the padding token that pads sentence batches to an equal length.
# use only the training set to train our tokenizer
split = 'train'
source_input_path = os.path.join(directory, f'{split}.{source_lang}')
target_input_path = os.path.join(directory, f'{split}.{target_lang}')
print(source_input_path, target_input_path)
bos_token = '<s>'
unk_token = '<unk>'
eos_token = '</s>'
pad_token = '<pad>'
special_tokens = [unk_token, bos_token, eos_token, pad_token]
tokenizer_params = {
'min_frequency': 2,
'vocab_size': 5000,
'show_progress': False,
'special_tokens': special_tokens
}
start_time = time.time()
source_tokenizer = ByteLevelBPETokenizer(lowercase=True)
source_tokenizer.train(source_input_path, **tokenizer_params)
target_tokenizer = ByteLevelBPETokenizer(lowercase=True)
target_tokenizer.train(target_input_path, **tokenizer_params)
end_time = time.time()
print('elapsed: ', end_time - start_time)
print('source vocab size: ', source_tokenizer.get_vocab_size())
print('target vocab size: ', target_tokenizer.get_vocab_size())
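To get a feel for what the freshly trained byte-level BPE produces, the sketch below encodes a made-up German sentence; .encode returns an Encoding object whose .tokens and .ids fields we can inspect.
# sketch: inspect the trained source tokenizer on an illustrative sentence (not from the dataset)
encoding = source_tokenizer.encode('zwei männer gehen die straße entlang')
print(encoding.tokens)
print(encoding.ids)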
We'll perform this tokenization step for our entire dataset up front, so we can do as little preprocessing as possible while feeding our dataset to the model. Note that we do not perform the padding step at this stage.
pad_token_id = source_tokenizer.token_to_id(pad_token)
eos_token_id = source_tokenizer.token_to_id(eos_token)
def batch_encode_fn(examples):
sources = examples[source_lang]
targets = examples[target_lang]
input_ids = [encoding.ids + [eos_token_id] for encoding in source_tokenizer.encode_batch(sources)]
labels = [encoding.ids + [eos_token_id] for encoding in target_tokenizer.encode_batch(targets)]
examples['input_ids'] = input_ids
examples['labels'] = labels
return examples
dataset_dict_encoded = dataset_dict.map(batch_encode_fn, batched=True, num_proc=8)
dataset_dict_encoded
dataset_train = dataset_dict_encoded['train']
dataset_train[0]
Given the custom tokenizer, we also implement a custom data collator class that does the padding for both inputs and labels.
class Seq2SeqDataCollator:
def __init__(
self,
max_length: int,
pad_token_id: int,
pad_label_token_id: int = -100
):
self.max_length = max_length
self.pad_token_id = pad_token_id
self.pad_label_token_id = pad_label_token_id
def __call__(self, batch):
source_batch = []
source_len = []
target_batch = []
target_len = []
for example in batch:
source = example['input_ids']
source_len.append(len(source))
source_batch.append(source)
target = example['labels']
target_len.append(len(target))
target_batch.append(target)
source_padded = self.process_encoded_text(source_batch, source_len, self.pad_token_id)
target_padded = self.process_encoded_text(target_batch, target_len, self.pad_label_token_id)
attention_mask = generate_attention_mask(source_padded, self.pad_token_id)
return {
'input_ids': source_padded,
'labels': target_padded,
'attention_mask': attention_mask
}
def process_encoded_text(self, sequences, sequences_len, pad_token_id):
sequences_max_len = np.max(sequences_len)
max_length = min(sequences_max_len, self.max_length)
padded_sequences = pad_sequences(sequences, max_length, pad_token_id)
return torch.LongTensor(padded_sequences)
def generate_attention_mask(input_ids, pad_token_id):
return (input_ids != pad_token_id).long()
def pad_sequences(sequences, max_length, pad_token_id):
"""
Pad the list of sequences (numerical token ids) to the same length.
Sequences that are shorter than the specified ``max_length`` will be appended
with the specified ``pad_token_id``. Those that are longer will be truncated.
Parameters
----------
sequences : list[list[int]]
List of lists of numerical token ids.
max_length : int
Maximum length that all sequences will be truncated/padded to.
pad_token_id : int
Padding token index.
Returns
-------
padded_sequences : 2d ndarray
"""
num_samples = len(sequences)
padded_sequences = np.full((num_samples, max_length), pad_token_id)
for i, sequence in enumerate(sequences):
sequence = np.array(sequence)[:max_length]
padded_sequences[i, :len(sequence)] = sequence
return padded_sequences
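Before wiring this collator into the trainer, we can exercise it on a couple of hypothetical, already-encoded examples (the token ids below are made up) to confirm that inputs get padded with the pad token, labels with -100, and the attention mask zeroes out the padded positions.
# sketch: run the collator on toy examples with made-up token ids to inspect its output format
toy_batch = [
    {'input_ids': [5, 6, 7, eos_token_id], 'labels': [8, 9, eos_token_id]},
    {'input_ids': [5, 6, eos_token_id], 'labels': [8, 9, 10, 11, eos_token_id]}
]
toy_collator = Seq2SeqDataCollator(max_length=128, pad_token_id=pad_token_id)
toy_collator(toy_batch)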
Given that we are using our own tokenizer instead of the pre-trained one, we need to update a couple of other parameters in our config. The one that's worth pointing out is that this model starts generating with pad_token_id
, which is why the decoder_start_token_id
is set to the same value as the pad_token_id
.
The rest of the model training code is the same as in the previous section.
config_params = {
'd_model': 256,
'decoder_layers': 3,
'decoder_attention_heads': 8,
'decoder_ffn_dim': 512,
'encoder_layers': 6,
'encoder_attention_heads': 8,
'encoder_ffn_dim': 512,
'max_length': 128,
'max_position_embeddings': 128,
'eos_token_id': eos_token_id,
'pad_token_id': pad_token_id,
'decoder_start_token_id': pad_token_id,
"bad_words_ids": [
[
pad_token_id
]
],
'vocab_size': source_tokenizer.get_vocab_size()
}
model_config = AutoConfig.from_pretrained(model_checkpoint, **config_params)
model_config
transformers_model = AutoModelForSeq2SeqLM.from_config(model_config)
print('# of parameters: ', transformers_model.num_parameters())
transformers_model
batch_size = 128
args = Seq2SeqTrainingArguments(
output_dir="test-translation",
evaluation_strategy="epoch",
learning_rate=0.0005,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
weight_decay=0.01,
save_total_limit=3,
num_train_epochs=20,
load_best_model_at_end=True,
predict_with_generate=True,
remove_unused_columns=True,
fp16=True
)
data_collator = Seq2SeqDataCollator(model_config.max_length, pad_token_id)
callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
trainer = Seq2SeqTrainer(
transformers_model,
args,
train_dataset=dataset_dict_encoded["train"],
eval_dataset=dataset_dict_encoded["val"],
data_collator=data_collator,
callbacks=callbacks
)
dataloader_train = trainer.get_train_dataloader()
next(iter(dataloader_train))
trainer_output = trainer.train()
trainer_output
def generate_translation(model, source_tokenizer, target_tokenizer, example):
source = example[source_lang]
target = example[target_lang]
input_ids = source_tokenizer.encode(source).ids
input_ids = torch.LongTensor(input_ids).view(1, -1).to(model.device)
generated_ids = model.generate(input_ids)
generated_ids = generated_ids[0].detach().cpu().numpy()
prediction = target_tokenizer.decode(generated_ids)
print('source: ', source)
print('target: ', target)
print('prediction: ', prediction)
example = dataset_dict['train'][0]
generate_translation(transformers_model, source_tokenizer, target_tokenizer, example)
example = dataset_dict['test'][0]
generate_translation(transformers_model, source_tokenizer, target_tokenizer, example)
We confirm that saving and loading the model gives us identical predictions.
model_checkpoint = 'transformers_model'
transformers_model.save_pretrained(model_checkpoint)
transformers_model_loaded = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
example = dataset_dict['test'][0]
generate_translation(transformers_model_loaded, source_tokenizer, target_tokenizer, example)
As the last step, we'll write an inference function that performs batch scoring on a given dataset. Here we generate the predictions and save them in a pandas dataframe along with the source and the target.
len(dataset_dict_encoded['test'])
# we use a different data collator than the one we used for training and evaluating the model,
# replacing -100 in the labels with the pad token id during inferencing,
# as the tokenizer can't decode -100.
data_collator = Seq2SeqDataCollator(model_config.max_length, pad_token_id, pad_token_id)
data_loader = DataLoader(dataset_dict_encoded['test'], collate_fn=data_collator, batch_size=64)
data_loader
start = time.time()
rows = []
for example in data_loader:
input_ids = example['input_ids']
generated_ids = transformers_model.generate(input_ids.to(transformers_model.device))
generated_ids = generated_ids.detach().cpu().numpy()
predictions = target_tokenizer.decode_batch(generated_ids)
labels = example['labels'].detach().cpu().numpy()
targets = target_tokenizer.decode_batch(labels)
sources = source_tokenizer.decode_batch(input_ids.detach().cpu().numpy())
for source, target, prediction in zip(sources, targets, predictions):
row = [source, target, prediction]
rows.append(row)
end = time.time()
print('elapsed: ', end - start)
df_rows = pd.DataFrame(rows, columns=['source', 'target', 'prediction'])
df_rows