huggingface_torch_transformer
In [1]:
# code for loading the format for the notebook
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))

from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
Out[1]:
In [2]:
os.chdir(path)

# 1. magic for inline plot
# 2. magic to print version
# 3. magic so that the notebook will reload external python modules
# 4. magic to enable retina (high resolution) plots
# https://gist.github.com/minrk/3301035
%matplotlib inline
%load_ext watermark
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'

import os
import math
import time
import torch
import random
import numpy as np
import pandas as pd
from datasets import load_dataset
from torch.utils.data import DataLoader
from tokenizers import ByteLevelBPETokenizer

%watermark -a 'Ethen' -d -t -v -p datasets,numpy,torch,tokenizers,transformers
Ethen 2021-05-17 03:21:35 

CPython 3.6.5
IPython 7.16.1

datasets 1.1.3
numpy 1.18.1
torch 1.7.0+cu101
tokenizers 0.10.1
transformers 4.3.0

Machine Translation with Huggingface Transformer

In this article, we'll be leveraging Huggingface's Transformers library for a machine translation task. The library provides thousands of pretrained models that we can use on our own tasks. Apart from that, we'll also take a look at how to use its pre-built tokenizer and model architecture to train a model from scratch.

Data Preprocessing

We'll be using the Multi30k dataset to demonstrate the transformer model on a machine translation task. This German to English training set contains around 29K sentence pairs, a moderately sized dataset that lets us get results without waiting too long. We'll start off by downloading the raw dataset and extracting it. Feel free to swap this step with any other machine translation dataset.

In [3]:
import tarfile
import zipfile
import requests
import subprocess
from tqdm import tqdm
from urllib.parse import urlparse


def download_file(url: str, directory: str):
    """
    Download the file at ``url`` to ``directory``.
    Extract the file content to ``directory`` if the original file
    is a tar, tar.gz or zip file.

    Parameters
    ----------
    url : str
        url of the file.

    directory : str
        Directory to download the file.
    """
    response = requests.get(url, stream=True)
    response.raise_for_status()

    content_len = response.headers.get('Content-Length')
    total = int(content_len) if content_len is not None else 0

    os.makedirs(directory, exist_ok=True)
    file_name = get_file_name_from_url(url)
    file_path = os.path.join(directory, file_name)

    with tqdm(unit='B', total=total) as pbar, open(file_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                pbar.update(len(chunk))
                f.write(chunk)

    extract_compressed_file(file_path, directory)


def extract_compressed_file(compressed_file_path: str, directory: str):
    """
    Extract a compressed file to ``directory``. Supports zip, tar.gz, tgz,
    tar extensions.

    Parameters
    ----------
    compressed_file_path : str

    directory : str
        File will be extracted to this directory.
    """
    basename = os.path.basename(compressed_file_path)
    if 'zip' in basename:
        with zipfile.ZipFile(compressed_file_path, "r") as zip_f:
            zip_f.extractall(directory)
    elif 'tar.gz' in basename or 'tgz' in basename:
        with tarfile.open(compressed_file_path) as f:
            f.extractall(directory)


def get_file_name_from_url(url: str) -> str:
    """
    Return the file_name from a URL

    Parameters
    ----------
    url : str
        URL to extract file_name from

    Returns
    -------
    file_name : str
    """
    parse = urlparse(url)
    return os.path.basename(parse.path)
In [4]:
urls = [
    'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',
    'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',
    'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/mmt16_task1_test.tar.gz'
]
directory = 'multi30k'
for url in urls:
    download_file(url, directory)
100%|██████████| 1207136/1207136 [00:01<00:00, 761127.60B/s]
100%|██████████| 46329/46329 [00:00<00:00, 162237.27B/s]
100%|██████████| 43905/43905 [00:00<00:00, 146122.64B/s]

We print out the contents of the data directory and some sample data.

In [5]:
!ls multi30k
mmt16_task1_test.tar.gz  test.en   train.en	    val.de  validation.tar.gz
test.de			 train.de  training.tar.gz  val.en
In [6]:
!head multi30k/train.de
Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.
Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem.
Ein kleines Mädchen klettert in ein Spielhaus aus Holz.
Ein Mann in einem blauen Hemd steht auf einer Leiter und putzt ein Fenster.
Zwei Männer stehen am Herd und bereiten Essen zu.
Ein Mann in grün hält eine Gitarre, während der andere Mann sein Hemd ansieht.
Ein Mann lächelt einen ausgestopften Löwen an.
Ein schickes Mädchen spricht mit dem Handy während sie langsam die Straße entlangschwebt.
Eine Frau mit einer großen Geldbörse geht an einem Tor vorbei.
Jungen tanzen mitten in der Nacht auf Pfosten.
In [7]:
!head multi30k/train.en
Two young, White males are outside near many bushes.
Several men in hard hats are operating a giant pulley system.
A little girl climbing into a wooden playhouse.
A man in a blue shirt is standing on a ladder cleaning a window.
Two men are at the stove preparing food.
A man in green holds a guitar while the other man observes his shirt.
A man is smiling at a stuffed lion
A trendy girl talking on her cellphone while gliding slowly down the street.
A woman with a large purse is walking by a gate.
Boys dancing on poles in the middle of the night.

The original dataset splits the source and the target language into two separate files (e.g. train.de and train.en are the training files for German and English). This type of format is useful when we wish to train a tokenizer on top of the source or target language, as we'll soon see.

On the other hand, having the source and target pairs together in a single file makes it easier to load them in batches for training or evaluating our machine translation model. We'll create the paired dataset, sanity check it, then load it with Huggingface's datasets library; some basic familiarity with that library will be helpful here.

In [8]:
def create_translation_data(
    source_input_path: str,
    target_input_path: str,
    output_path: str,
    delimiter: str = '\t',
    encoding: str = 'utf-8'
):
    """
    Creates the paired source and target dataset from the separated ones.
    e.g. creates `train.tsv` from `train.de` and `train.en`
    """
    with open(source_input_path, encoding=encoding) as f_source_in, \
         open(target_input_path, encoding=encoding) as f_target_in, \
         open(output_path, 'w', encoding=encoding) as f_out:

        for source_raw in f_source_in:
            source_raw = source_raw.strip()
            target_raw = f_target_in.readline().strip()
            if source_raw and target_raw:
                output_line = source_raw + delimiter + target_raw + '\n'
                f_out.write(output_line)
In [9]:
source_lang = 'de'
target_lang = 'en'

data_files = {}
for split in ['train', 'val', 'test']:
    source_input_path = os.path.join(directory, f'{split}.{source_lang}')
    target_input_path = os.path.join(directory, f'{split}.{target_lang}')
    output_path = f'{split}.tsv'
    create_translation_data(source_input_path, target_input_path, output_path)
    data_files[split] = [output_path]

data_files
Out[9]:
{'train': ['train.tsv'], 'val': ['val.tsv'], 'test': ['test.tsv']}
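
Before loading the paired files with the datasets library, we can take a quick peek at one of them. This is just a small sanity check sketch using pandas (imported at the top of the notebook); the column names here are ours.

# read back the generated tab separated file; quoting=3 (QUOTE_NONE) since the raw text may contain quote characters
df_train = pd.read_csv('train.tsv', sep='\t', names=[source_lang, target_lang], quoting=3)
print(df_train.shape)
df_train.head(2)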
In [10]:
dataset_dict = load_dataset(
    'csv',
    delimiter='\t',
    column_names=[source_lang, target_lang],
    data_files=data_files
)
dataset_dict
Using custom data configuration default
Downloading and preparing dataset csv/default-8e377772020fbbd4 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/mingyuliu/.cache/huggingface/datasets/csv/default-8e377772020fbbd4/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2...
Dataset csv downloaded and prepared to /home/mingyuliu/.cache/huggingface/datasets/csv/default-8e377772020fbbd4/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2. Subsequent calls will reuse this data.
Out[10]:
DatasetDict({
    train: Dataset({
        features: ['de', 'en'],
        num_rows: 29000
    })
    val: Dataset({
        features: ['de', 'en'],
        num_rows: 1014
    })
    test: Dataset({
        features: ['de', 'en'],
        num_rows: 1000
    })
})

We can access each split, and each record/pair within it, with the following syntax.

In [11]:
dataset_dict['train'][0]
Out[11]:
{'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',
 'en': 'Two young, White males are outside near many bushes.'}

Pretrained Model

To get started, we'll use the MarianMT pretrained translation model.

The first thing we'll do is load the pre-trained tokenizer using the from_pretrained method. This ensures we get the tokenizer and vocabulary corresponding to the model architecture for this specific checkpoint.

In [12]:
from transformers import AutoTokenizer

model_checkpoint = "Helsinki-NLP/opus-mt-de-en"
pretrained_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
pretrained_tokenizer
Out[12]:
PreTrainedTokenizer(name_or_path='Helsinki-NLP/opus-mt-de-en', vocab_size=58101, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'})

We can pass a single record or a list of records to huggingface's tokenizer. Depending on the model, we might see different keys in the returned dictionary. For example, here we have:

  • input_ids: The raw input text converted into numerical ids by the tokenizer.
  • attention_mask: Mask to avoid performing attention on padded token ids. As we haven't performed the padding step yet, the numbers all show up as 1, indicating they are not masked (a sketch of padding in action follows the next few cells).
In [13]:
pretrained_tokenizer(dataset_dict['train']['de'][0])
Out[13]:
{'input_ids': [3303, 5338, 17270, 2843, 70, 49, 14991, 5, 9, 1413, 10949, 14243, 3351, 3, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
In [14]:
# notice the last token id is 0, the end of sentence special token
pretrained_tokenizer.convert_ids_to_tokens(0)
Out[14]:
'</s>'
In [15]:
pretrained_tokenizer(dataset_dict['train']['de'][0:2])
Out[15]:
{'input_ids': [[3303, 5338, 17270, 2843, 70, 49, 14991, 5, 9, 1413, 10949, 14243, 3351, 3, 0], [20520, 2843, 30, 1235, 19116, 15, 14570, 53, 17992, 3013, 1947, 3, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
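
To see the attention mask do something more interesting, we can ask the tokenizer to pad the batch to its longest example. A small sketch of the same call as above with padding turned on; the padded positions show up as 0 in the attention mask.

# pad the two sentences to the same length and inspect the resulting mask
padded = pretrained_tokenizer(dataset_dict['train']['de'][0:2], padding=True)
padded['attention_mask']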

We can apply the tokenizer to our entire raw dataset up front, so this preprocessing will be a one time process. By passing a function to our dataset dict's map method, the same tokenizing step is applied to all the splits in our data.
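
Before defining that function, it helps to know what map with batched=True hands to it: a slice of a dataset comes back as a dictionary mapping column names to lists of values. A quick illustration:

# slicing a split returns a dict of column name -> list of values
sample_batch = dataset_dict['train'][:2]
print(type(sample_batch))
print({key: len(value) for key, value in sample_batch.items()})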

In [16]:
max_source_length = 128
max_target_length = 128
source_lang = "de"
target_lang = "en"


def batch_tokenize_fn(examples):
    """
    Generate the input_ids and labels field for huggingface dataset/dataset dict.
    
    Truncation is enabled, so we cap each sentence to the max length. Padding will be done later
    in a data collator, which pads examples to the longest length in the batch rather than the whole dataset.
    """
    sources = examples[source_lang]
    targets = examples[target_lang]
    model_inputs = pretrained_tokenizer(sources, max_length=max_source_length, truncation=True)

    # setup the tokenizer for targets,
    # huggingface expects the target tokenized ids to be stored in the labels field
    with pretrained_tokenizer.as_target_tokenizer():
        labels = pretrained_tokenizer(targets, max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
In [17]:
tokenized_dataset_dict = dataset_dict.map(batch_tokenize_fn, batched=True, num_proc=8)
tokenized_dataset_dict
 
Out[17]:
DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'de', 'en', 'input_ids', 'labels'],
        num_rows: 29000
    })
    val: Dataset({
        features: ['attention_mask', 'de', 'en', 'input_ids', 'labels'],
        num_rows: 1014
    })
    test: Dataset({
        features: ['attention_mask', 'de', 'en', 'input_ids', 'labels'],
        num_rows: 1000
    })
})
In [18]:
# printing out the tokenized data, to check for the newly added fields
tokenized_dataset_dict['train'][0]
Out[18]:
{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',
 'en': 'Two young, White males are outside near many bushes.',
 'input_ids': [3303,
  5338,
  17270,
  2843,
  70,
  49,
  14991,
  5,
  9,
  1413,
  10949,
  14243,
  3351,
  3,
  0],
 'labels': [4386, 1296, 2, 3380, 25020, 48, 2060, 1656, 374, 45315, 3, 0]}

Having prepared our dataset, we'll load the pre-trained model. Similar to the tokenizer, we can use the .from_pretrained method, and specify a valid huggingface model.

In [19]:
from transformers import AutoModelForSeq2SeqLM

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

pretrained_model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
print('# of parameters: ', pretrained_model.num_parameters())
pretrained_model
# of parameters:  74410496
Out[19]:
MarianMTModel(
  (model): MarianModel(
    (shared): Embedding(58101, 512, padding_idx=58100)
    (encoder): MarianEncoder(
      (embed_tokens): Embedding(58101, 512, padding_idx=58100)
      (embed_positions): MarianSinusoidalPositionalEmbedding(512, 512)
      (layers): ModuleList(
        (0): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (1): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (2): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (3): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (4): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (5): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (decoder): MarianDecoder(
      (embed_tokens): Embedding(58101, 512, padding_idx=58100)
      (embed_positions): MarianSinusoidalPositionalEmbedding(512, 512)
      (layers): ModuleList(
        (0): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (1): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (2): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (3): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (4): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (5): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=512, out_features=512, bias=True)
            (v_proj): Linear(in_features=512, out_features=512, bias=True)
            (q_proj): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=512, out_features=2048, bias=True)
          (fc2): Linear(in_features=2048, out_features=512, bias=True)
          (final_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=512, out_features=58101, bias=False)
)
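
As a quick check, the number reported by num_parameters() above should match summing up the parameter sizes by hand:

# equivalent by-hand count of all model parameters
sum(p.numel() for p in pretrained_model.parameters())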

We can directly use this model to generate the translations, and eyeball the results.

In [20]:
def generate_translation(model, tokenizer, example):
    """print out the source, target and predicted raw text."""
    source = example[source_lang]
    target = example[target_lang]
    input_ids = example['input_ids']
    input_ids = torch.LongTensor(input_ids).view(1, -1).to(model.device)
    generated_ids = model.generate(input_ids)
    prediction = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    
    print('source: ', source)
    print('target: ', target)
    print('prediction: ', prediction)
In [21]:
example = tokenized_dataset_dict['train'][0]
generate_translation(pretrained_model, pretrained_tokenizer, example)
source:  Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.
target:  Two young, White males are outside near many bushes.
prediction:  Two young white men are outdoors near many bushes.
In [22]:
example = tokenized_dataset_dict['test'][0]
generate_translation(pretrained_model, pretrained_tokenizer, example)
source:  Ein Mann mit einem orangefarbenen Hut, der etwas anstarrt.
target:  A man in an orange hat starring at something.
prediction:  A man with an orange hat staring at something.
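
As an aside, generate exposes common decoding knobs such as beam width and maximum length. A small sketch with illustrative values (the pretrained MarianMT config already defaults to beam search with num_beams=4):

# batch translate two test sentences with explicit, illustrative decoding options
inputs = pretrained_tokenizer(dataset_dict['test']['de'][:2], padding=True, return_tensors='pt').to(device)
generated_ids = pretrained_model.generate(**inputs, num_beams=5, max_length=128, early_stopping=True)
pretrained_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)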

Training Model From Scratch

The next section shows the steps for training the model parameters from scratch. Instead of directly instantiating the model with the .from_pretrained method, we use the .from_config method, which builds a model for a particular architecture with randomly initialized weights. The configuration itself is still created using .from_pretrained, but we override some of its hyperparameters, opting for a smaller model for faster iteration.

In [23]:
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)

config_params = {
    'd_model': 256,
    'decoder_layers': 3,
    'decoder_attention_heads': 8,
    'decoder_ffn_dim': 512,
    'encoder_layers': 6,
    'encoder_attention_heads': 8,
    'encoder_ffn_dim': 512,
    'max_length': 128,
    'max_position_embeddings': 128
}

model_checkpoint = "Helsinki-NLP/opus-mt-de-en"
config = AutoConfig.from_pretrained(model_checkpoint, **config_params)
config
Out[23]:
MarianConfig {
  "_num_labels": 3,
  "activation_dropout": 0.0,
  "activation_function": "swish",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "MarianMTModel"
  ],
  "attention_dropout": 0.0,
  "bad_words_ids": [
    [
      58100
    ]
  ],
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "classifier_dropout": 0.0,
  "d_model": 256,
  "decoder_attention_heads": 8,
  "decoder_ffn_dim": 512,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 3,
  "decoder_start_token_id": 58100,
  "dropout": 0.1,
  "encoder_attention_heads": 8,
  "encoder_ffn_dim": 512,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 0,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_length": 128,
  "max_position_embeddings": 128,
  "model_type": "marian",
  "normalize_before": false,
  "normalize_embedding": false,
  "num_beams": 4,
  "num_hidden_layers": 6,
  "pad_token_id": 58100,
  "scale_embedding": true,
  "static_position_embeddings": true,
  "transformers_version": "4.3.0",
  "use_cache": true,
  "vocab_size": 58101
}
In [24]:
model = AutoModelForSeq2SeqLM.from_config(config)
print('# of parameters: ', model.num_parameters())
model
# of parameters:  20474368
Out[24]:
MarianMTModel(
  (model): MarianModel(
    (shared): Embedding(58101, 256, padding_idx=58100)
    (encoder): MarianEncoder(
      (embed_tokens): Embedding(58101, 256, padding_idx=58100)
      (embed_positions): MarianSinusoidalPositionalEmbedding(128, 256)
      (layers): ModuleList(
        (0): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
        (1): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
        (2): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
        (3): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
        (4): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
        (5): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (decoder): MarianDecoder(
      (embed_tokens): Embedding(58101, 256, padding_idx=58100)
      (embed_positions): MarianSinusoidalPositionalEmbedding(128, 256)
      (layers): ModuleList(
        (0): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
        (1): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
        (2): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=256, out_features=58101, bias=False)
)

The huggingface library offers pre-built functionality to avoid writing the training logic from scratch. This step can be swapped out with other higher level trainer packages or even our own training loop. We set up:

  • Seq2SeqTrainingArguments: a class that contains all the attributes to customize the training. At the bare minimum, it requires one folder name, which will be used to save model checkpoints.
  • DataCollatorForSeq2Seq: a helper class that batches our examples; this is where the padding logic resides (a sketch of it in action follows the next cell).
In [25]:
batch_size = 128
args = Seq2SeqTrainingArguments(
    output_dir="test-translation",
    evaluation_strategy="epoch",
    learning_rate=0.0005,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=20,
    load_best_model_at_end=True,
    predict_with_generate=True,
    remove_unused_columns=True,
    fp16=True
)

data_collator = DataCollatorForSeq2Seq(pretrained_tokenizer)

callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
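
To see the padding logic at work, we can call the data collator directly on a couple of examples. A small sketch; we drop the raw text columns ourselves here, whereas the trainer does this for us via remove_unused_columns.

# keep only the fields the collator/model expects, then collate two training examples
features = [
    {key: value for key, value in tokenized_dataset_dict['train'][i].items()
     if key in ('input_ids', 'attention_mask', 'labels')}
    for i in range(2)
]
collated = data_collator(features)
collated['labels']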
In [26]:
trainer = Seq2SeqTrainer(
    model,
    args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset_dict["train"],
    eval_dataset=tokenized_dataset_dict["val"],
    callbacks=callbacks
)

We can take a look at the batched examples. Understanding the output can be beneficial if we wish to customize the data collate function later.

  • attention_mask: Padded tokens are masked out with 0.
  • input_ids: Input ids are padded with the padding special token.
  • labels: By default, -100 is automatically ignored by PyTorch loss functions, hence we use that particular id when padding our labels (a small sketch after the next cell shows why).
In [27]:
dataloader_train = trainer.get_train_dataloader()
batch = next(iter(dataloader_train))
batch
Out[27]:
{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'input_ids': tensor([[  246,  1155,     5,  ..., 58100, 58100, 58100],
         [  525,   788,     2,  ..., 58100, 58100, 58100],
         [  246,  2902, 18756,  ..., 58100, 58100, 58100],
         ...,
         [ 3303,  2843,     2,  ..., 58100, 58100, 58100],
         [  246,  2902,  1251,  ..., 58100, 58100, 58100],
         [  246,  5324,  8055,  ..., 58100, 58100, 58100]]),
 'labels': tensor([[   93,   175,     5,  ...,  -100,  -100,  -100],
         [   93,  2950,    19,  ...,  -100,  -100,  -100],
         [   93,  4040,  5074,  ...,  -100,  -100,  -100],
         ...,
         [ 4386,  1135, 25345,  ...,  -100,  -100,  -100],
         [   93,  4040,  2047,  ...,  -100,  -100,  -100],
         [   93,   839,  6799,  ...,  -100,  -100,  -100]])}
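
The reason -100 works as the label padding value is that PyTorch's cross entropy loss ignores that index by default, so padded positions contribute nothing to the loss. A minimal sketch:

import torch.nn.functional as F

logits = torch.randn(3, 5)           # 3 positions, vocabulary of size 5
labels = torch.tensor([2, 4, -100])  # the last position mimics a padded label
loss_all = F.cross_entropy(logits, labels)             # ignore_index defaults to -100
loss_unpadded = F.cross_entropy(logits[:2], labels[:2])
print(torch.isclose(loss_all, loss_unpadded))          # tensor(True)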
In [28]:
trainer_output = trainer.train()
trainer_output
[2497/4540 07:34 < 06:12, 5.49 it/s, Epoch 11/20]
Epoch Training Loss Validation Loss Runtime Samples Per Second
1 No log 3.701715 0.654800 1548.522000
2 No log 2.559382 0.669000 1515.684000
3 3.932700 2.114859 0.746600 1358.109000
4 3.932700 1.917877 0.940700 1077.914000
5 1.959900 1.787804 0.582700 1740.033000
6 1.959900 1.738364 0.497900 2036.748000
7 1.408800 1.711116 0.647000 1567.237000
8 1.408800 1.710581 0.572500 1771.124000
9 1.106000 1.723755 0.717300 1413.717000
10 1.106000 1.741999 0.558800 1814.496000
11 1.106000 1.731579 0.590900 1716.157000

Out[28]:
TrainOutput(global_step=2497, training_loss=1.8638566963950491, metrics={'train_runtime': 455.0706, 'train_samples_per_second': 9.976, 'total_flos': 1390184690368512, 'epoch': 11.0})

Similar to what we did before, we can use this model to generate the translations, and eyeball the results.

In [29]:
def generate_translation(model, tokenizer, example):
    """print out the source, target and predicted raw text."""
    source = example[source_lang]
    target = example[target_lang]
    input_ids = tokenizer(source)['input_ids']
    input_ids = torch.LongTensor(input_ids).view(1, -1).to(model.device)
    generated_ids = model.generate(input_ids)
    prediction = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    print('source: ', source)
    print('target: ', target)
    print('prediction: ', prediction)
In [30]:
example = dataset_dict['train'][0]
generate_translation(model, pretrained_tokenizer, example)
source:  Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.
target:  Two young, White males are outside near many bushes.
prediction:  Two young white men are outside near many bushes.
In [31]:
example = dataset_dict['test'][0]
generate_translation(model, pretrained_tokenizer, example)
source:  Ein Mann mit einem orangefarbenen Hut, der etwas anstarrt.
target:  A man in an orange hat starring at something.
prediction:  A man with an orange hat looking at something.

Training Tokenizer and Model From Scratch

From our raw pairs, we need to use or train a tokenizer to convert them into numerical indices. Here we'll be training our tokenizer from scratch using Huggingface's tokenizers library. Feel free to swap this step out with other tokenization procedures; what's important is to leave room for special tokens such as the init token that represents the beginning of a sentence, the end of sentence token, the unknown token, and the padding token that pads sentence batches to equal length.

In [32]:
# use only the training set to train our tokenizer
split = 'train'
source_input_path = os.path.join(directory, f'{split}.{source_lang}')
target_input_path = os.path.join(directory, f'{split}.{target_lang}')
print(source_input_path, target_input_path)
multi30k/train.de multi30k/train.en
In [33]:
bos_token = '<s>'
unk_token = '<unk>'
eos_token = '</s>'
pad_token = '<pad>'
special_tokens = [unk_token, bos_token, eos_token, pad_token]

tokenizer_params = {
    'min_frequency': 2,
    'vocab_size': 5000,
    'show_progress': False,
    'special_tokens': special_tokens
}

start_time = time.time()
source_tokenizer = ByteLevelBPETokenizer(lowercase=True)
source_tokenizer.train(source_input_path, **tokenizer_params)

target_tokenizer = ByteLevelBPETokenizer(lowercase=True)
target_tokenizer.train(target_input_path, **tokenizer_params)
end_time = time.time()

print('elapsed: ', end_time - start_time)
print('source vocab size: ', source_tokenizer.get_vocab_size())
print('target vocab size: ', target_tokenizer.get_vocab_size())
elapsed:  6.176600694656372
source vocab size:  5000
target vocab size:  5000
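
As a quick sanity check on the freshly trained tokenizer, we can round trip a sentence. The exact ids depend on the vocabulary learned on your run, and since we passed lowercase=True the decoded text comes back lowercased.

# encode a training sentence, inspect the subword tokens and ids, then decode back
encoding = source_tokenizer.encode('Zwei junge weiße Männer sind im Freien.')
print(encoding.tokens)
print(encoding.ids)
print(source_tokenizer.decode(encoding.ids))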

We'll perform this tokenization step on our entire dataset up front, so we do as little preprocessing as possible while feeding data to the model. Note that we do not perform the padding step at this stage.

In [34]:
pad_token_id = source_tokenizer.token_to_id(pad_token)
eos_token_id = source_tokenizer.token_to_id(eos_token)
In [35]:
def batch_encode_fn(examples):
    sources = examples[source_lang]
    targets = examples[target_lang]

    input_ids = [encoding.ids + [eos_token_id] for encoding in source_tokenizer.encode_batch(sources)]
    labels = [encoding.ids + [eos_token_id] for encoding in target_tokenizer.encode_batch(targets)]

    examples['input_ids'] = input_ids
    examples['labels'] = labels
    return examples
In [36]:
dataset_dict_encoded = dataset_dict.map(batch_encode_fn, batched=True, num_proc=8)
dataset_dict_encoded
        
Out[36]:
DatasetDict({
    train: Dataset({
        features: ['de', 'en', 'input_ids', 'labels'],
        num_rows: 29000
    })
    val: Dataset({
        features: ['de', 'en', 'input_ids', 'labels'],
        num_rows: 1014
    })
    test: Dataset({
        features: ['de', 'en', 'input_ids', 'labels'],
        num_rows: 1000
    })
})
In [37]:
dataset_train = dataset_dict_encoded['train']
dataset_train[0]
Out[37]:
{'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',
 'en': 'Two young, White males are outside near many bushes.',
 'input_ids': [344,
  378,
  1191,
  413,
  649,
  349,
  660,
  281,
  327,
  726,
  1284,
  263,
  728,
  707,
  17,
  2],
 'labels': [336, 373, 15, 370, 2182, 321, 494, 557, 1203, 3158, 17, 2]}

Given the custom tokenizer, we also write our own data collator class that does the padding for the inputs and labels (a usage sketch follows the class definition).

In [38]:
class Seq2SeqDataCollator:
    
    def __init__(
        self,
        max_length: int,
        pad_token_id: int,
        pad_label_token_id: int = -100
    ):
        self.max_length = max_length
        self.pad_token_id = pad_token_id
        self.pad_label_token_id = pad_label_token_id
        
    def __call__(self, batch):
        source_batch = []
        source_len = []
        target_batch = []
        target_len = []
        for example in batch:
            source = example['input_ids']
            source_len.append(len(source))
            source_batch.append(source)

            target = example['labels']
            target_len.append(len(target))
            target_batch.append(target)

        source_padded = self.process_encoded_text(source_batch, source_len, self.pad_token_id)
        target_padded = self.process_encoded_text(target_batch, target_len, self.pad_label_token_id)
        attention_mask = generate_attention_mask(source_padded, self.pad_token_id)
        return {
            'input_ids': source_padded,
            'labels': target_padded,
            'attention_mask': attention_mask
        }

    def process_encoded_text(self, sequences, sequences_len, pad_token_id):
        sequences_max_len = np.max(sequences_len)
        max_length = min(sequences_max_len, self.max_length)
        padded_sequences = pad_sequences(sequences, max_length, pad_token_id)
        return torch.LongTensor(padded_sequences)


def generate_attention_mask(input_ids, pad_token_id):
    return (input_ids != pad_token_id).long()

    
def pad_sequences(sequences, max_length, pad_token_id):
    """
    Pad the list of sequences (numerical token ids) to the same length.
    Sequences that are shorter than the specified ``max_length`` will be padded
    with the specified ``pad_token_id``. Those that are longer will be truncated.

    Parameters
    ----------
    sequences : list[list[int]]
        List of sequences of numerical token ids.

    max_length : int
         Maximum length that all sequences will be truncated/padded to.

    pad_token_id : int
        Padding token index.

    Returns
    -------
    padded_sequences : 2d ndarray
    """
    num_samples = len(sequences)
    padded_sequences = np.full((num_samples, max_length), pad_token_id)
    for i, sequence in enumerate(sequences):
        sequence = np.array(sequence)[:max_length]
        padded_sequences[i, :len(sequence)] = sequence

    return padded_sequences
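
A quick usage sketch of the collator defined above, padding a couple of encoded training examples. The pad_token_id comes from the cells above; the max_length of 128 is just the cap we've been using throughout.

# collate two examples: inputs are padded with pad_token_id, labels with -100
seq2seq_data_collator = Seq2SeqDataCollator(max_length=128, pad_token_id=pad_token_id)
collated_batch = seq2seq_data_collator([dataset_train[0], dataset_train[1]])
print(collated_batch['input_ids'].shape, collated_batch['labels'].shape)
print(collated_batch['attention_mask'][0])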

Given that we are using our own tokenizer instead of the pre-trained one, we need to update a couple of other parameters in our config. The one that's worth pointing out is that this model starts generating with the pad token, which is why decoder_start_token_id is set to the same value as pad_token_id.

The rest of the model training code is the same as in the previous section.
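
To make the decoder_start_token_id remark concrete: huggingface seq2seq models derive decoder_input_ids from the labels by shifting them one position to the right and prepending decoder_start_token_id, so generation starts from that token. A rough sketch of the shift, using a few illustrative label ids:

labels = torch.tensor([[336, 373, 15, 2]])  # a few target ids ending in </s>
decoder_start = torch.full((1, 1), pad_token_id, dtype=torch.long)
decoder_input_ids = torch.cat([decoder_start, labels[:, :-1]], dim=-1)
print(decoder_input_ids)  # the pad/start token followed by the shifted labels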

In [39]:
config_params = {
    'd_model': 256,
    'decoder_layers': 3,
    'decoder_attention_heads': 8,
    'decoder_ffn_dim': 512,
    'encoder_layers': 6,
    'encoder_attention_heads': 8,
    'encoder_ffn_dim': 512,
    'max_length': 128,
    'max_position_embeddings': 128,
    'eos_token_id': eos_token_id,
    'pad_token_id': pad_token_id,
    'decoder_start_token_id': pad_token_id,
    "bad_words_ids": [
        [
            pad_token_id
        ]
    ],
    'vocab_size': source_tokenizer.get_vocab_size()
}

model_config = AutoConfig.from_pretrained(model_checkpoint, **config_params)
model_config
Out[39]:
MarianConfig {
  "_num_labels": 3,
  "activation_dropout": 0.0,
  "activation_function": "swish",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "MarianMTModel"
  ],
  "attention_dropout": 0.0,
  "bad_words_ids": [
    [
      3
    ]
  ],
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "classifier_dropout": 0.0,
  "d_model": 256,
  "decoder_attention_heads": 8,
  "decoder_ffn_dim": 512,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 3,
  "decoder_start_token_id": 3,
  "dropout": 0.1,
  "encoder_attention_heads": 8,
  "encoder_ffn_dim": 512,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_length": 128,
  "max_position_embeddings": 128,
  "model_type": "marian",
  "normalize_before": false,
  "normalize_embedding": false,
  "num_beams": 4,
  "num_hidden_layers": 6,
  "pad_token_id": 3,
  "scale_embedding": true,
  "static_position_embeddings": true,
  "transformers_version": "4.3.0",
  "use_cache": true,
  "vocab_size": 5000
}
In [40]:
transformers_model = AutoModelForSeq2SeqLM.from_config(model_config)
print('# of parameters: ', transformers_model.num_parameters())
transformers_model
# of parameters:  6880512
Out[40]:
MarianMTModel(
  (model): MarianModel(
    (shared): Embedding(5000, 256, padding_idx=3)
    (encoder): MarianEncoder(
      (embed_tokens): Embedding(5000, 256, padding_idx=3)
      (embed_positions): MarianSinusoidalPositionalEmbedding(128, 256)
      (layers): ModuleList(
        (0): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
        (1): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
        (2): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
        (3): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
        (4): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
        (5): MarianEncoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (decoder): MarianDecoder(
      (embed_tokens): Embedding(5000, 256, padding_idx=3)
      (embed_positions): MarianSinusoidalPositionalEmbedding(128, 256)
      (layers): ModuleList(
        (0): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
        (1): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
        (2): MarianDecoderLayer(
          (self_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): MarianAttention(
            (k_proj): Linear(in_features=256, out_features=256, bias=True)
            (v_proj): Linear(in_features=256, out_features=256, bias=True)
            (q_proj): Linear(in_features=256, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=256, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=256, bias=True)
          (final_layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (lm_head): Linear(in_features=256, out_features=5000, bias=False)
)
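
The roughly 6.9 million parameters also reflect that the lm_head projection shares its weight matrix with the shared token embedding (weight tying, which is on by default for this architecture). A quick sanity check, which should evaluate to True if the weights are indeed tied:

transformers_model.get_input_embeddings().weight is transformers_model.get_output_embeddings().weight
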
In [41]:
batch_size = 128
args = Seq2SeqTrainingArguments(
    output_dir="test-translation",
    evaluation_strategy="epoch",
    learning_rate=0.0005,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=20,
    load_best_model_at_end=True,
    predict_with_generate=True,
    remove_unused_columns=True,
    fp16=True
)

data_collator = Seq2SeqDataCollator(model_config.max_length, pad_token_id)

callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
In [42]:
trainer = Seq2SeqTrainer(
    transformers_model,
    args,
    train_dataset=dataset_dict_encoded["train"],
    eval_dataset=dataset_dict_encoded["val"],
    data_collator=data_collator,
    callbacks=callbacks
)
In [43]:
dataloader_train = trainer.get_train_dataloader()
next(iter(dataloader_train))
Out[43]:
{'input_ids': tensor([[262, 294, 281,  ...,   3,   3,   3],
         [297, 318,  15,  ...,   3,   3,   3],
         [262, 386, 672,  ...,   3,   3,   3],
         ...,
         [344, 413,  15,  ...,   3,   3,   3],
         [262, 386, 546,  ...,   3,   3,   3],
         [262, 563, 378,  ...,   3,   3,   3]]),
 'labels': tensor([[  68,  292,  271,  ..., -100, -100, -100],
         [  68,  326,  293,  ..., -100, -100, -100],
         [  68,  376,  662,  ..., -100, -100, -100],
         ...,
         [ 336,  401,  560,  ..., -100, -100, -100],
         [  68,  376, 1130,  ..., -100, -100, -100],
         [  68,  505,  385,  ..., -100, -100, -100]]),
 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]])}
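
Note the -100 values in labels: they mark padded target positions, which the loss ignores, since PyTorch's cross entropy uses ignore_index=-100 by default. A tiny illustration with random logits:

import torch.nn.functional as F

logits = torch.randn(3, 5000)           # 3 target positions, vocabulary size of 5000
labels = torch.tensor([68, 292, -100])  # the last position is padding
F.cross_entropy(logits, labels)         # loss is averaged over the 2 real tokens only
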
In [44]:
trainer_output = trainer.train()
trainer_output
[3405/4540 07:43 < 02:34, 7.34 it/s, Epoch 15/20]
Epoch  Training Loss  Validation Loss  Runtime (s)  Samples Per Second
1      No log         3.598508         0.477500     2123.690000
2      No log         2.741445         0.630800     1607.404000
3      3.840200       2.316170         0.509800     1989.048000
4      3.840200       2.078891         0.713500     1421.158000
5      2.274600       1.941849         0.540400     1876.244000
6      2.274600       1.841438         0.608600     1666.216000
7      1.767100       1.781287         0.657200     1543.026000
8      1.767100       1.747373         0.599300     1691.906000
9      1.486600       1.719654         0.743000     1364.803000
10     1.486600       1.704974         0.617300     1642.552000
11     1.486600       1.701151         0.575000     1763.431000
12     1.294600       1.692111         0.519600     1951.319000
13     1.294600       1.693845         0.487700     2079.081000
14     1.136700       1.702049         0.508200     1995.281000
15     1.136700       1.706282         0.527000     1923.935000

Out[44]:
TrainOutput(global_step=3405, training_loss=1.8559266660357012, metrics={'train_runtime': 463.7701, 'train_samples_per_second': 9.789, 'total_flos': 643119585140736, 'epoch': 15.0})
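
Training stops after epoch 15 instead of the configured 20 because the EarlyStoppingCallback observed no improvement in validation loss for 3 consecutive evaluations (epochs 13 through 15). If we want to inspect this programmatically, the logged evaluation losses are kept on the trainer's state, e.g.:

eval_losses = [log['eval_loss'] for log in trainer.state.log_history if 'eval_loss' in log]
eval_losses[-4:]  # the tail shows the plateau that triggered early stopping
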
In [45]:
def generate_translation(model, source_tokenizer, target_tokenizer, example):
    source = example[source_lang]
    target = example[target_lang]
    input_ids = source_tokenizer.encode(source).ids
    input_ids = torch.LongTensor(input_ids).view(1, -1).to(model.device)
    generated_ids = model.generate(input_ids)
    generated_ids = generated_ids[0].detach().cpu().numpy()

    prediction = target_tokenizer.decode(generated_ids)
    print('source: ', source)
    print('target: ', target)
    print('prediction: ', prediction)
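
Note that model.generate relies on the defaults we stored in the model config, e.g. beam search with num_beams=4 and max_length=128, along with the eos/pad ids and bad_words_ids set earlier. The snippet below should be equivalent to the call inside generate_translation, just with those two arguments spelled out explicitly:

example = dataset_dict['train'][0]
input_ids = torch.LongTensor(
    source_tokenizer.encode(example[source_lang]).ids
).view(1, -1).to(transformers_model.device)

generated_ids = transformers_model.generate(
    input_ids,
    num_beams=transformers_model.config.num_beams,   # 4
    max_length=transformers_model.config.max_length  # 128
)
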
In [46]:
example = dataset_dict['train'][0]
generate_translation(transformers_model, source_tokenizer, target_tokenizer, example)
source:  Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.
target:  Two young, White males are outside near many bushes.
prediction:  two young white men are outside near many bushes.
In [47]:
example = dataset_dict['test'][0]
generate_translation(transformers_model, source_tokenizer, target_tokenizer, example)
source:  Ein Mann mit einem orangefarbenen Hut, der etwas anstarrt.
target:  A man in an orange hat starring at something.
prediction:  a man in an orange hat who is looking at something.

Confirming that saving and loading the model gives us identical predictions.

In [48]:
model_checkpoint = 'transformers_model'
transformers_model.save_pretrained(model_checkpoint)
In [49]:
transformers_model_loaded = transformers_model.from_pretrained(model_checkpoint).to(device)
In [50]:
example = dataset_dict['test'][0]
generate_translation(transformers_model_loaded, source_tokenizer, target_tokenizer, example)
source:  Ein Mann mit einem orangefarbenen Hut, der etwas anstarrt.
target:  A man in an orange hat starring at something.
prediction:  a man in an orange hat who is looking at something.
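
For a stricter check than eyeballing the decoded text, we can compare the generated token ids from the original and the re-loaded model directly; with identical weights and deterministic beam search this should return True (a sketch, assuming both models sit on the device we've been using):

example = dataset_dict['test'][0]
input_ids = torch.LongTensor(source_tokenizer.encode(example[source_lang]).ids).view(1, -1)

generated_original = transformers_model.generate(input_ids.to(transformers_model.device))
generated_loaded = transformers_model_loaded.generate(input_ids.to(transformers_model_loaded.device))
torch.equal(generated_original.cpu(), generated_loaded.cpu())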

As the last step, we'll write an inference routine that performs batch scoring on a given dataset. Here we generate the predictions and save them in a pandas DataFrame along with the source and the target.

In [51]:
len(dataset_dict_encoded['test'])
Out[51]:
1000
In [58]:
# we use a different data collator than the one used for training and evaluation:
# for inference we replace -100 in the labels with the pad token,
# since the tokenizer can't decode -100.
data_collator = Seq2SeqDataCollator(model_config.max_length, pad_token_id, pad_token_id)
data_loader = DataLoader(dataset_dict_encoded['test'], collate_fn=data_collator, batch_size=64)
data_loader
Out[58]:
<torch.utils.data.dataloader.DataLoader at 0x7f1384a4e358>
In [59]:
start = time.time()
rows = []
for example in data_loader:
    input_ids = example['input_ids']
    generated_ids = transformers_model.generate(input_ids.to(transformers_model.device))
    generated_ids = generated_ids.detach().cpu().numpy()
    predictions = target_tokenizer.decode_batch(generated_ids)

    labels = example['labels'].detach().cpu().numpy()
    targets = target_tokenizer.decode_batch(labels)

    sources = source_tokenizer.decode_batch(input_ids.detach().cpu().numpy())
    for source, target, prediction in zip(sources, targets, predictions):
        row = [source, target, prediction]
        rows.append(row)

end = time.time()
print('elapsed: ', end - start)
df_rows = pd.DataFrame(rows, columns=['source', 'target', 'prediction'])
df_rows
elapsed:  12.964367628097534
Out[59]:
source target prediction
0 ein mann mit einem orangefarbenen hut, der etw... a man in an orange hat starring at something. a man in an orange hat looking at something.
1 ein boston terrier läuft über saftig-grünes gr... a boston terrier is running on lush green gras... a boststst skier runs in front of a white fenc...
2 ein mädchen in einem karateanzug bricht einen ... a girl in karate uniform breaking a stick with... a girl in a karate uniform gets a stick with a...
3 fünf leute in winterjacken und mit helmen steh... five people wearing winter jackets and helmets... five people wearing winter jackets and helmets...
4 leute reparieren das dach eines hauses. people are fixing the roof of a house. people are fixing the roof of a house.
... ... ... ...
995 marathonläufer laufen auf einer städtischen st... marathon runners are racing on a city street, ... a marathon runner running around a city street...
996 asiatische frau trägt einen sonnenhut beim fah... asian woman wearing a sunhat while riding a bike. asian woman wearing a sunhat riding a bicycle.
997 ein paar kinder sind im freien und spielen auf... some children are outside playing in the dirt ... a couple of children are outside playing on th...
998 ein älterer mann spielt ein videospiel. an older man is playing a video arcade game. an older man is playing a video game.
999 ein mädchen an einer küste mit einem berg im h... a girl at the shore of a beach with a mountain... a girl at a shore with a mountain in the backg...

1000 rows × 3 columns

Reference