None torch_transformer
In [1]:
# code for loading the format for the notebook
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))

from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
Out[1]:
In [2]:
os.chdir(path)

# 1. magic for inline plot
# 2. magic to print version
# 3. magic so that the notebook will reload external python modules
# 4. magic to enable retina (high resolution) plots
# https://gist.github.com/minrk/3301035
%matplotlib inline
%load_ext watermark
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'

import os
import math
import time
import torch
import random
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tokenizers import ByteLevelBPETokenizer
from datasets import load_dataset, disable_progress_bar

# prevents progress bar and logging from flooding our document
disable_progress_bar()

%watermark -a 'Ethen' -d -t -v -u -p datasets,numpy,torch,tokenizers
/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
Author: Ethen

Last updated: 2023-09-11 22:48:58

Python implementation: CPython
Python version       : 3.10.6
IPython version      : 8.13.2

datasets  : 2.14.4
numpy     : 1.23.2
torch     : 2.0.1
tokenizers: 0.13.3

Transformer

Seq2Seq based machine translation system usually comprises of two main components, an encoder that encodes in source sentence into context vectors and a decoder that decodes the context vectors into target sentence, transformer model is no different in this regards. Reasons to their growing popularity at the time of writing this document are primarily due to self attention layers and parallel computation.

Previous RNN based encoder and decoder has a constraint of sequential computation. A hidden state at time $t$ in a recurrent layer, has only seen tokens $x_t$ and all the tokens before it, even though this gives us the benefit of modeling long dependencies, it hinders training speed as we can't process the next time step until we finish processing the current one. Transformer model aims to mitigate this issue by solely relying on attention mechanism, where each context vector produced by a transformer model has seen all tokens at all positions within the input sequence. In other words, instead of compressing the entire source sentence, $X = (x_1, ... , x_n)$ into a single context vector, $z$, it produces a sequence of context vectors, $Z = (z_1, ... , z_n)$ in one parallel computation. We'll get to the details of attention mechanism, self attention, that's used throughout the Transformer model in later sections. One important thing to note here is that breakthrough of this model is not due to invention of the attention mechansim, as this concept existed well before. The highlight here is we can build a highly performant model with attention mechanism in isolation, i.e. without the use of recurrent (RNN) or convolutional (CNN) neural networks in the mix.

In this article, we will be implementing Transformer module from the famous Attention is all you need paper [9]. This implementation's structure is largely based on [1]. With the primary difference that we'll be using huggingface's dataset instead of torchtext for data loading, as well as show casing how to implement Transformer module leveraging PyTorch built in Transformer Encoder and Decoder block.

Data Preprocessing

We'll be using the Multi30k dataset to demonstrate using the transfomer model in a machine translation task. This German to English training dataset's size is around 29K. We'll start off by downloading the raw dataset and extracting them. Feel free to swap this step with any other machine translation dataset. If the original link for these datasets fails to load, use this alternative google drive link.

In [3]:
import tarfile
import zipfile
import requests
import subprocess
from tqdm import tqdm
from urllib.parse import urlparse


def download_file(url: str, directory: str):
    """
    Download the file at ``url`` to ``directory``.
    Extract to the file content ``directory`` if the original file
    is a tar, tar.gz or zip file.

    Parameters
    ----------
    url : str
        url of the file.

    directory : str
        Directory to download the file.
    """
    response = requests.get(url, stream=True)
    response.raise_for_status()

    content_len = response.headers.get('Content-Length')
    total = int(content_len) if content_len is not None else 0

    os.makedirs(directory, exist_ok=True)
    file_name = get_file_name_from_url(url)
    file_path = os.path.join(directory, file_name)

    with tqdm(unit='B', total=total) as pbar, open(file_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                pbar.update(len(chunk))
                f.write(chunk)

    extract_compressed_file(file_path, directory)


def extract_compressed_file(compressed_file_path: str, directory: str):
    """
    Extract a compressed file to ``directory``. Supports zip, tar.gz, tgz,
    tar extensions.

    Parameters
    ----------
    compressed_file_path : str

    directory : str
        File will to extracted to this directory.
    """
    basename = os.path.basename(compressed_file_path)
    if 'zip' in basename:
        with zipfile.ZipFile(compressed_file_path, "r") as zip_f:
            zip_f.extractall(directory)
    elif 'tar.gz' in basename or 'tgz' in basename:
        with tarfile.open(compressed_file_path) as f:
            f.extractall(directory)


def get_file_name_from_url(url: str) -> str:
    """
    Return the file_name from a URL

    Parameters
    ----------
    url : str
        URL to extract file_name from

    Returns
    -------
    file_name : str
    """
    parse = urlparse(url)
    return os.path.basename(parse.path)
In [4]:
urls = [
    'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz',
    'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz',
    'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/mmt16_task1_test.tar.gz'
]
directory = './translation/wmt16'
for url in urls:
    download_file(url, directory)

We print out the content in the data directory and some sample data.

In [5]:
!ls $directory
mmt16_task1_test.tar.gz  test.en   train.en	    val.de  validation.tar.gz
test.de			 train.de  training.tar.gz  val.en
In [6]:
!head $directory/train.de
Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.
Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem.
Ein kleines Mädchen klettert in ein Spielhaus aus Holz.
Ein Mann in einem blauen Hemd steht auf einer Leiter und putzt ein Fenster.
Zwei Männer stehen am Herd und bereiten Essen zu.
Ein Mann in grün hält eine Gitarre, während der andere Mann sein Hemd ansieht.
Ein Mann lächelt einen ausgestopften Löwen an.
Ein schickes Mädchen spricht mit dem Handy während sie langsam die Straße entlangschwebt.
Eine Frau mit einer großen Geldbörse geht an einem Tor vorbei.
Jungen tanzen mitten in der Nacht auf Pfosten.
In [7]:
!head $directory/train.en
Two young, White males are outside near many bushes.
Several men in hard hats are operating a giant pulley system.
A little girl climbing into a wooden playhouse.
A man in a blue shirt is standing on a ladder cleaning a window.
Two men are at the stove preparing food.
A man in green holds a guitar while the other man observes his shirt.
A man is smiling at a stuffed lion
A trendy girl talking on her cellphone while gliding slowly down the street.
A woman with a large purse is walking by a gate.
Boys dancing on poles in the middle of the night.

The original dataset is splits the source and the target language into two separate files (e.g. train.de, train.en are the training dataset for German and English). This type of format is useful when we wish to train a tokenizer on top of the source or target language as we'll soon see.

On the other hand, having the source and target pair together in one single file makes it easier to load them in batches for training or evaluating our machine translation model. We'll create the paired dataset, and load the dataset. For loading the dataset, it will be helpful to have some basic understanding of Huggingface's dataset.

In [8]:
def create_translation_data(
    source_input_path: str,
    target_input_path: str,
    output_path: str,
    delimiter: str = '\t',
    encoding: str = 'utf-8'
):
    """
    Creates the paired source and target dataset from the separated ones.
    e.g. creates `train.tsv` from `train.de` and `train.en`
    """
    with open(source_input_path, encoding=encoding) as f_source_in, \
         open(target_input_path, encoding=encoding) as f_target_in, \
         open(output_path, 'w', encoding=encoding) as f_out:

        for source_raw in f_source_in:
            source_raw = source_raw.strip()
            target_raw = f_target_in.readline().strip()
            if source_raw and target_raw:
                output_line = source_raw + delimiter + target_raw + '\n'
                f_out.write(output_line)
In [9]:
source_lang = 'de'
target_lang = 'en'

data_files = {}
for split in ['train', 'val', 'test']:
    source_input_path = os.path.join(directory, f'{split}.{source_lang}')
    target_input_path = os.path.join(directory, f'{split}.{target_lang}')
    output_path = f'{split}.tsv'
    create_translation_data(source_input_path, target_input_path, output_path)
    data_files[split] = [output_path]

data_files
Out[9]:
{'train': ['train.tsv'], 'val': ['val.tsv'], 'test': ['test.tsv']}
In [10]:
dataset_dict = load_dataset(
    'csv',
    delimiter='\t',
    column_names=[source_lang, target_lang],
    data_files=data_files
)
dataset_dict
Out[10]:
DatasetDict({
    train: Dataset({
        features: ['de', 'en'],
        num_rows: 29000
    })
    val: Dataset({
        features: ['de', 'en'],
        num_rows: 1014
    })
    test: Dataset({
        features: ['de', 'en'],
        num_rows: 1000
    })
})

We can access each split, and record/pair with the following syntax.

In [11]:
dataset_dict['train'][0]
Out[11]:
{'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',
 'en': 'Two young, White males are outside near many bushes.'}

From our raw pair, we need to use or train a tokenizer to convert them into numerical indices. Here we'll be training our tokenizer from scratch using Huggingface's tokenizer. Feel free to swap this step out with other tokenization procedures, what's important is to leave rooms for special tokens such as the init token that represents the start of a sentence, the end of sentence token that represents the end of a sentence, and padding token that pads sentence batches into equivalent length.

In [12]:
# use only the training set to train our tokenizer
split = 'train'
source_input_path = os.path.join(directory, f'{split}.{source_lang}')
target_input_path = os.path.join(directory, f'{split}.{target_lang}')
print(source_input_path, target_input_path)
./translation/wmt16/train.de ./translation/wmt16/train.en
In [13]:
init_token = '<sos>'
eos_token = '<eos>'
pad_token = '<pad>'

tokenizer_params = {
    'min_frequency': 2,
    'vocab_size': 5000,
    'show_progress': False,
    'special_tokens': [init_token, eos_token, pad_token]
}

start_time = time.time()
source_tokenizer = ByteLevelBPETokenizer(lowercase=True)
source_tokenizer.train(source_input_path, **tokenizer_params)

target_tokenizer = ByteLevelBPETokenizer(lowercase=True)
target_tokenizer.train(target_input_path, **tokenizer_params)
end_time = time.time()

print('elapsed: ', end_time - start_time)
print('source vocab size: ', source_tokenizer.get_vocab_size())
print('target vocab size: ', target_tokenizer.get_vocab_size())
elapsed:  1.4205942153930664
source vocab size:  5000
target vocab size:  5000
In [14]:
source_eos_idx = source_tokenizer.token_to_id(eos_token)
target_eos_idx = target_tokenizer.token_to_id(eos_token)

source_init_idx = source_tokenizer.token_to_id(init_token)
target_init_idx = target_tokenizer.token_to_id(init_token)

We'll perform this tokenization step for all our dataset up front, so we can do as little preprocessing as possible while feeding our dataset to model. Note that we do not perform the padding step at this stage.

In [15]:
def encode(example):
    """
    Encode the raw text into numerical token ids. Creating two new fields
    ``source_ids`` and ``target_ids``.
    Also append the init token and prepend eos token to the sentence.
    """
    source_raw = example[source_lang]
    target_raw = example[target_lang]
    source_encoded = source_tokenizer.encode(source_raw).ids
    source_encoded = [source_init_idx] + source_encoded + [source_eos_idx]
    target_encoded = target_tokenizer.encode(target_raw).ids
    target_encoded = [target_init_idx] + target_encoded + [target_eos_idx]
    example['source_ids'] = source_encoded
    example['target_ids'] = target_encoded
    return example


start_time = time.time()
dataset_dict_encoded = dataset_dict.map(encode, num_proc=8)
end_time = time.time()
print('elapsed: ', end_time - start_time)

dataset_dict_encoded
elapsed:  2.2638769149780273
Out[15]:
DatasetDict({
    train: Dataset({
        features: ['de', 'en', 'source_ids', 'target_ids'],
        num_rows: 29000
    })
    val: Dataset({
        features: ['de', 'en', 'source_ids', 'target_ids'],
        num_rows: 1014
    })
    test: Dataset({
        features: ['de', 'en', 'source_ids', 'target_ids'],
        num_rows: 1000
    })
})
In [16]:
dataset_train = dataset_dict_encoded['train']
dataset_train[0]
Out[16]:
{'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',
 'en': 'Two young, White males are outside near many bushes.',
 'source_ids': [0,
  343,
  377,
  1190,
  412,
  648,
  348,
  659,
  280,
  326,
  725,
  1283,
  262,
  727,
  706,
  16,
  1],
 'target_ids': [0, 335, 372, 14, 369, 2181, 320, 493, 556, 1202, 3157, 16, 1]}

The final step for our data preprocessing step is to prepare the DataLoader, which prepares batches of tokenized ids for our model. The customized collate function performs the batching as well as padding.

In [17]:
class TranslationPairCollate:

    def __init__(self, max_len, pad_idx, device, percentage=100):
        self.device = device
        self.max_len = max_len
        self.pad_idx = pad_idx
        self.percentage = percentage

    def __call__(self, batch):
        source_batch = []
        source_len = []
        target_batch = []
        target_len = []
        for example in batch:
            source = example['source_ids']
            source_len.append(len(source))
            source_batch.append(source)

            target = example['target_ids']
            target_len.append(len(target))
            target_batch.append(target)

        source_padded = self.process_encoded_text(source_batch, source_len, self.max_len, self.pad_idx)
        target_padded = self.process_encoded_text(target_batch, target_len, self.max_len, self.pad_idx)
        return source_padded, target_padded

    def process_encoded_text(self, sequences, sequences_len, max_len, pad_idx):
        sequences_len_percentile = int(np.percentile(sequences_len, self.percentage))
        max_len = min(sequences_len_percentile, max_len)
        padded_sequences = pad_sequences(sequences, max_len, pad_idx)
        return torch.LongTensor(padded_sequences)


def pad_sequences(sequences, max_len, pad_idx):
    """
    Pad the list of sequences (numerical token ids) to the same length.
    Sequence that are shorter than the specified ``max_len`` will be appended
    with the specified ``pad_idx``. Those that are longer will be truncated.

    Parameters
    ----------
    sequences : list[int]
        List of numerical token ids.

    max_len : int
         Maximum length of all sequences.

    pad_idx : int
        Padding index.

    Returns
    -------
    padded_sequences : 1d ndarray
    """
    num_samples = len(sequences)
    padded_sequences = np.full((num_samples, max_len), pad_idx)
    for i, sequence in enumerate(sequences):
        sequence = np.array(sequence)[:max_len]
        padded_sequences[i, :len(sequence)] = sequence

    return padded_sequences
In [18]:
max_len = 100
batch_size = 128
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

pad_idx = source_tokenizer.token_to_id(pad_token)
translation_pair_collate_fn = TranslationPairCollate(max_len, pad_idx, device)

data_loader_params = {
    'batch_size': batch_size,
    'collate_fn': translation_pair_collate_fn,
    'pin_memory': True
}

dataloader_train = DataLoader(dataset_train, **data_loader_params)

# we can print out 1 batch of source and target
source, target = next(iter(dataloader_train))
source, target
Out[18]:
(tensor([[  0, 343, 377,  ...,   2,   2,   2],
         [  0, 640, 412,  ...,   2,   2,   2],
         [  0, 261, 542,  ...,   2,   2,   2],
         ...,
         [  0, 343, 500,  ...,   2,   2,   2],
         [  0, 296, 442,  ...,   2,   2,   2],
         [  0, 296, 317,  ...,   2,   2,   2]]),
 tensor([[  0, 335, 372,  ...,   2,   2,   2],
         [  0, 808, 400,  ...,   2,   2,   2],
         [  0,  67, 504,  ...,   2,   2,   2],
         ...,
         [  0, 335, 479,  ...,   2,   2,   2],
         [  0,  67, 413,  ...,   2,   2,   2],
         [  0,  67, 325,  ...,   2,   2,   2]]))
In [19]:
# create the data loader for both validation and test set
dataset_val = dataset_dict_encoded['val']
dataloader_val = DataLoader(dataset_val, **data_loader_params)

dataset_test = dataset_dict_encoded['test']
dataloader_test = DataLoader(dataset_test, **data_loader_params)

Model Architecture From Scratch

Having prepared the data, we can now start implementing Transformer model's architecture, which looks like the following:

Position Wise Embedding

First, input tokens are passed through a standard embedding layer. Next, as the entire sentence is fed into the model in one go, by default it has no idea about the tokens' order within the sequence. We cope with this by using a second embedding layer, positional embedding. This is an embedding layer where our input is not the token id but the token's position within the sequence. If we configure our position embedding to have a "vocabulary" size of 100, this means our model can accept sentences up to 100 tokens long.

The original Transformer implementation from the Attention is All You Need paper does not learn positional embeddings. Instead it uses a fixed static positional encoding. Modern Transformer architectures, like BERT, use positional embeddings, hence, we have decided to use them in these tutorials. Feel free to check out other tutorials [7] [8] to read more about positional encoding used in the original Transformer model.

Next, token and positional embeddings are combined together using an elementwise sum operation, giving us a single vector that contains information on both the token and its position with in the sequence. Before they are summed, token embeddings are multiplied by a scaling factor $\sqrt{d_{model}}$, where $d_{model}$ is the hidden dimension size, hid_dim. This supposedly reduces variance in the embeddings and without this scaling factor, it becomes difficult to train the model reliably. Dropout is then applied to the combined embeddings.

In [20]:
class PositionWiseEmbedding(nn.Module):

    def __init__(self, input_dim, hid_dim, max_len, dropout_p):
        super().__init__()
        self.input_dim = input_dim
        self.hid_dim = hid_dim
        self.max_len = max_len
        self.dropout_p = dropout_p

        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_len, hid_dim)
        self.dropout = nn.Dropout(dropout_p)
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim]))

    def forward(self, inputs):

        # inputs = [batch size, inputs len]
        batch_size = inputs.shape[0]
        inputs_len = inputs.shape[1]

        pos = torch.arange(0, inputs_len).unsqueeze(0).repeat(batch_size, 1).to(inputs.device)
        scale = self.scale.to(inputs.device)
        embedded = (self.tok_embedding(inputs) * scale) + self.pos_embedding(pos)

        # output = [batch size, inputs len, hid dim]
        output = self.dropout(embedded)
        return output
In [21]:
input_dim = source_tokenizer.get_vocab_size()
hid_dim = 64
max_len = 100
dropout_p = 0.5
embedding = PositionWiseEmbedding(input_dim, hid_dim, max_len, dropout_p).to(device)
embedding
Out[21]:
PositionWiseEmbedding(
  (tok_embedding): Embedding(5000, 64)
  (pos_embedding): Embedding(100, 64)
  (dropout): Dropout(p=0.5, inplace=False)
)
In [22]:
src_embedded = embedding(source.to(device))
src_embedded.shape
Out[22]:
torch.Size([128, 40, 64])

The combined embeddings are then passed through $N$ encoder layers to get our context vectors $Z$. Before jumping straight into the encoder layers, we'll introduce some of the core building blocks behind them.

Multi Head Attention Layer

One of the key concepts introduced by Transformer model is multi-head attention layer.

The purpose behind an attention mechanism is to relate inputs from different parts of the sequence. Attention operation is comprised of queries, keys and values. It might be helpful to look at these terms from an informational retrieval perspective, where every time we issue a query to a search engine, the search engine will match it with some key (title, description), and retrieve the associated value (content).

To be specific, Transformer model uses scaled dot-product attention, where query is used with key to get an attention vector, which is then used to get a weighted sum of the values.

\begin{align} \text{Attention}(Q, K, V) = \text{Softmax} \big( \frac{QK^T}{\sqrt{d_k}} \big)V \end{align}

Where $Q = XW^Q, K = XW^K, V = XW^V$, $X$ is our input matrix, $W^Q$, $W^K$, $W^V$ are linear layers for the query, key and value. $d_k$ is the head dimension, head_dim, which we will further explain shortly. In essence, we are multiplying our input matrix with 3 different weight matrices. We first peform a dot product between query and key followed by a softmax to calculate attention weight, which measures correlation between the two words, finally scaling it by $d_k$ before doing a dot product with the value to get the weighted value. Scaling is done to prevent the results of the dot product from growing too large, and causing the gradients to become too small.

Multi-head attention extends the single attention mechansim so we can potentially pay attention to different concepts that exists at different sequence positions. If end users are familiar with convolutional neural networks, this trick is very similar to introducing multiple filters so each can learn different aspects of the input. Instead of doing a single attention operation, the queries, keys and values have their hid_dim split into $h$ heads each of size $d_k$, and the scaled dot-product attention is calculated over all heads in parallel. After this computation, we re-combine the heads back to hid_dim shape. By reducing the dimensionality of each head/concept, the total computational cost is similar to a full dimension single-head attention.

\begin{align} \text{MultiHead}(Q, K, V) &= \text{Concat}(\text{head}_1,...,\text{head}_h)W^O \\ \text{head}_i &= \text{Attention}(Q_i, K_i, V_i) \end{align}

Where $W^O$ is the linear layer applied at the end of the multi-head attention layer.

In the implementation below, we carry out the multi head attention in parallel using batch matrix multiplication as opposed to a for loop. And while calculating the attention weights, we introduce the capability of applying a mask so the model does not pay attention to irrelevant tokens. We'll elaborate more on this in future sections.

In [23]:
class MultiHeadAttention(nn.Module):

    def __init__(self, hid_dim, n_heads):
        super().__init__()
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        assert hid_dim % n_heads == 0

        self.key_weight = nn.Linear(hid_dim, hid_dim)
        self.query_weight = nn.Linear(hid_dim, hid_dim)
        self.value_weight = nn.Linear(hid_dim, hid_dim)
        self.linear_weight = nn.Linear(hid_dim, hid_dim)

    def forward(self, query, key, value, mask = None):
        batch_size = query.shape[0]
        query_len = query.shape[1]
        key_len = key.shape[1]

        # key/query/value (proj) = [batch size, input len, hid dim]
        key_proj = self.key_weight(key)
        query_proj = self.query_weight(query)
        value_proj = self.value_weight(value)

        # compute the weights between query and key
        query_proj = query_proj.view(batch_size, query_len, self.n_heads, self.head_dim)
        query_proj = query_proj.permute(0, 2, 1, 3)
        key_proj = key_proj.view(batch_size, key_len, self.n_heads, self.head_dim)
        key_proj = key_proj.permute(0, 2, 3, 1)

        # energy, attention = [batch size, num heads, query len, key len]
        energy = torch.matmul(query_proj, key_proj) / math.sqrt(self.head_dim)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)

        attention = torch.softmax(energy, dim=-1)

        # output = [batch size, num heads, query len, head dim]
        value_proj = value_proj.view(batch_size, key_len, self.n_heads, self.head_dim)
        value_proj = value_proj.permute(0, 2, 1, 3)
        output = torch.matmul(attention, value_proj)

        # linaer = [batch size, query len, hid dim]
        output = output.permute(0, 2, 1, 3).contiguous().view(batch_size, query_len, self.hid_dim)
        linear_proj = self.linear_weight(output)
        return linear_proj, attention
In [24]:
n_heads = 8
self_attention = MultiHeadAttention(hid_dim, n_heads).to(device)
self_attention_output, attention = self_attention(src_embedded, src_embedded, src_embedded)
print(self_attention_output.shape)
print(attention.shape)
torch.Size([128, 40, 64])
torch.Size([128, 8, 40, 40])

Position Wise Feed Forward Layer

Another building block is the position wise feed forward layer, which consists of two linear transformations. These transformations are identical across different positions. i.e. feed forward layers are typically used on a tensor of shape (batch_size, hidden_dim), here it is directly operating on a tensor of shape (batch size, seq_len, hidden_dim).

The input is transformed from hid_dim to pf_dim, where pf_dim is usually a lot larger than hid_dim. Then an activation function is applied before it is transformed back into a hid_dim representation.

In [25]:
class PositionWiseFeedForward(nn.Module):

    def __init__(self, hid_dim, pf_dim):
        super().__init__()
        self.hid_dim = hid_dim
        self.pf_dim = pf_dim

        self.fc1 = nn.Linear(hid_dim, pf_dim)
        self.fc2 = nn.Linear(pf_dim, hid_dim)

    def forward(self, inputs):
        # inputs = [batch size, src len, hid dim]
        fc1_output = torch.relu(self.fc1(inputs))
        fc2_output = self.fc2(fc1_output)
        return fc2_output
In [26]:
hid_dim = 64
pf_dim = 256
position_ff = PositionWiseFeedForward(hid_dim, pf_dim).to(device)
position_ff_output = position_ff(self_attention_output)
position_ff_output.shape
Out[26]:
torch.Size([128, 40, 64])

Encoder

We'll now put our building blocks together to form the encoder.

We first pass the source sentence through a position wise embedding layer, this is then followed by N (configurable) encoder layers, the "meat" of modern transformer based architecture. The main role of our encoder is to update our embeddings/weights so that it can learn some contextual information about our text sequence, e.g. the word "bank" will be updated to be more "financial establishment" like and less "land along river" if words such as money and investment are close to it.

Inside the encoder layer, we start from the multi-head attention layer, perform dropout on it, apply a residual connection, pass it through a layer normalization layer. followed by a position-wise feedforward layer and then, again, apply dropout, a residual connection and then layer normalization to get the output, this is then fed into the next layer. This sounds like a mouthful, but potentially the code will clarify things a bit. Things worth noting:

  • Parameters are not shared between layers.
  • Multi head attention layer is used by the encoder layer to attend to the source sentence, i.e. it is calculating and applying attention over itself instead of another sequence, hence we call it self attention. This layer is the only layer that propagates information along the sequence, other layers operate on each individual token in isolation.
  • The gist behind layer normalization is that it normalizes the features' values across the hidden dimension so each feature has a mean of 0 and a standard deviation of 1. This trick along with residual connection, makes it easier to train neural networks with a larger number of layers, like the Transformer.
In [27]:
class EncoderLayer(nn.Module):

    def __init__(self, hid_dim, n_heads, pf_dim, dropout_p):
        super().__init__()
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.dropout_p = dropout_p

        self.self_attention_layer_norm = nn.LayerNorm(hid_dim)
        self.position_ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttention(hid_dim, n_heads)
        self.position_ff = PositionWiseFeedForward(hid_dim, pf_dim)

        self.dropout = nn.Dropout(dropout_p)

    def forward(self, src, src_mask):
        # src = [batch size, src len, hid dim]
        # src_mask = [batch size, 1, 1, src len] 
        self_attention_output, _ = self.self_attention(src, src, src, src_mask)

        # residual connection and layer norm
        self_attention_output = self.dropout(self_attention_output)
        self_attention_output = self.self_attention_layer_norm(src + self_attention_output)

        position_ff_output = self.position_ff(self_attention_output)

        # residual connection and layer norm
        # [batch size, src len, hid dim]
        position_ff_output = self.dropout(position_ff_output)
        output = self.position_ff_layer_norm(self_attention_output + position_ff_output)        
        return output
In [28]:
class Encoder(nn.Module):

    def __init__(self, input_dim, hid_dim, max_len, dropout_p, n_heads, pf_dim, n_layers):
        super().__init__()
        self.input_dim = input_dim
        self.hid_dim = hid_dim
        self.max_len = max_len
        self.dropout_p = dropout_p
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.n_layers = n_layers

        self.pos_wise_embedding = PositionWiseEmbedding(input_dim, hid_dim, max_len, dropout_p)
        self.layers = nn.ModuleList([
            EncoderLayer(hid_dim, n_heads, pf_dim, dropout_p)
            for _ in range(n_layers)
        ])

    def forward(self, src, src_mask = None):

        # src = [batch size, src len]
        # src_mask = [batch size, 1, 1, src len]
        src = self.pos_wise_embedding(src)
        for layer in self.layers:
            src = layer(src, src_mask)

        # [batch size, src len, hid dim]
        return src
In [29]:
input_dim = source_tokenizer.get_vocab_size()
hid_dim = 64
max_len = 100
dropout_p = 0.5
n_heads = 8
pf_dim = 256
n_layers = 1
encoder = Encoder(input_dim, hid_dim, max_len, dropout_p, n_heads, pf_dim, n_layers).to(device)
encoder
Out[29]:
Encoder(
  (pos_wise_embedding): PositionWiseEmbedding(
    (tok_embedding): Embedding(5000, 64)
    (pos_embedding): Embedding(100, 64)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (layers): ModuleList(
    (0): EncoderLayer(
      (self_attention_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (position_ff_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (self_attention): MultiHeadAttention(
        (key_weight): Linear(in_features=64, out_features=64, bias=True)
        (query_weight): Linear(in_features=64, out_features=64, bias=True)
        (value_weight): Linear(in_features=64, out_features=64, bias=True)
        (linear_weight): Linear(in_features=64, out_features=64, bias=True)
      )
      (position_ff): PositionWiseFeedForward(
        (fc1): Linear(in_features=64, out_features=256, bias=True)
        (fc2): Linear(in_features=256, out_features=64, bias=True)
      )
      (dropout): Dropout(p=0.5, inplace=False)
    )
  )
)
In [30]:
encoder_output = encoder(source.to(device))
encoder_output.shape
Out[30]:
torch.Size([128, 40, 64])

Decoder

Now comes the decoder part:

Decoder's main goal is to take our source sentence's encoded representation, $Z$, convert it into predicted tokens in the target sentence, $\hat{Y}$. We then compare it with the actual tokens in the target sentence, $Y$, to calculate our loss and update our parameters to improve our predictions.

Decoder layer contains similar building blocks as the encoder layer, except it now has two multi-head attention layers, self_attention and encoder_attention.

The former attention layer performs self attention on our target sentence's embedding representation to generate a decoder representation. Whereas for the encoder/decoder attention layer, decoder's intermediate presentation will represent queries, whereas keys and values come from encoder representation's output.

In [31]:
class DecoderLayer(nn.Module):

    def __init__(self, hid_dim, n_heads, pf_dim, dropout_p):
        super().__init__()
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.dropout_p = dropout_p

        self.self_attention_layer_norm = nn.LayerNorm(hid_dim)
        self.encoder_attention_layer_norm = nn.LayerNorm(hid_dim)
        self.position_ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttention(hid_dim, n_heads)
        self.encoder_attention = MultiHeadAttention(hid_dim, n_heads)
        self.position_ff = PositionWiseFeedForward(hid_dim, pf_dim)
        
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, trg, encoded_src, trg_mask, src_mask):
        # encoded_src = [batch size, src len, hid dim]
        # src_mask = [batch size, 1, 1, src len] 
        self_attention_output, _ = self.self_attention(trg, trg, trg, trg_mask)

        # residual connection and layer norm
        self_attention_output = self.dropout(self_attention_output)
        self_attention_output = self.self_attention_layer_norm(trg + self_attention_output)

        encoder_attention_output, _ = self.encoder_attention(
            self_attention_output,
            encoded_src,
            encoded_src,
            src_mask
        )
        encoder_attention_output = self.dropout(encoder_attention_output)
        encoder_attention_output = self.encoder_attention_layer_norm(trg + encoder_attention_output)

        position_ff_output = self.position_ff(encoder_attention_output)

        # residual connection and layer norm
        # [batch size, src len, hid dim]
        position_ff_output = self.dropout(position_ff_output)
        output = self.position_ff_layer_norm(encoder_attention_output + position_ff_output)        
        return output
In [32]:
class Decoder(nn.Module):

    def __init__(self, output_dim, hid_dim, max_len, dropout_p, n_heads, pf_dim, n_layers):
        super().__init__()
        self.output_dim = output_dim
        self.hid_dim = hid_dim
        self.max_len = max_len
        self.dropout_p = dropout_p
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.n_layers = n_layers

        self.pos_wise_embedding = PositionWiseEmbedding(output_dim, hid_dim, max_len, dropout_p)
        self.layers = nn.ModuleList([
            DecoderLayer(hid_dim, n_heads, pf_dim, dropout_p)
            for _ in range(n_layers)
        ])
        self.fc_out = nn.Linear(hid_dim, output_dim)

    def forward(self, trg, encoded_src, trg_mask = None, src_mask = None):

        trg = self.pos_wise_embedding(trg)
        for layer in self.layers:
            trg = layer(trg, encoded_src, trg_mask, src_mask)
        
        output = self.fc_out(trg)
        return output
In [33]:
output_dim = target_tokenizer.get_vocab_size()
hid_dim = 64
max_len = 100
dropout_p = 0.5
n_heads = 8
pf_dim = 32
n_layers = 1
decoder = Decoder(output_dim, hid_dim, max_len, dropout_p, n_heads, pf_dim, n_layers).to(device)
decoder
Out[33]:
Decoder(
  (pos_wise_embedding): PositionWiseEmbedding(
    (tok_embedding): Embedding(5000, 64)
    (pos_embedding): Embedding(100, 64)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (layers): ModuleList(
    (0): DecoderLayer(
      (self_attention_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (encoder_attention_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (position_ff_layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (self_attention): MultiHeadAttention(
        (key_weight): Linear(in_features=64, out_features=64, bias=True)
        (query_weight): Linear(in_features=64, out_features=64, bias=True)
        (value_weight): Linear(in_features=64, out_features=64, bias=True)
        (linear_weight): Linear(in_features=64, out_features=64, bias=True)
      )
      (encoder_attention): MultiHeadAttention(
        (key_weight): Linear(in_features=64, out_features=64, bias=True)
        (query_weight): Linear(in_features=64, out_features=64, bias=True)
        (value_weight): Linear(in_features=64, out_features=64, bias=True)
        (linear_weight): Linear(in_features=64, out_features=64, bias=True)
      )
      (position_ff): PositionWiseFeedForward(
        (fc1): Linear(in_features=64, out_features=32, bias=True)
        (fc2): Linear(in_features=32, out_features=64, bias=True)
      )
      (dropout): Dropout(p=0.5, inplace=False)
    )
  )
  (fc_out): Linear(in_features=64, out_features=5000, bias=True)
)
In [34]:
encoder_output = encoder(source.to(device))
decoder_output = decoder(target.to(device), encoder_output)
decoder_output.shape
Out[34]:
torch.Size([128, 26, 5000])

Seq2Seq

Now that we have our encoder and decoder, the final part is to have a Seq2Seq module that encapsulates the two. In this module, we'll also handle masking.

The source mask is created by checking where our source sequence is not equal to the <pad> token. It is 1 where the token is not a token and 0 when it is. This is used in our encoder layers' multi-head attention mechanisms, where we want our model to not pay any attention to <pad> tokens, which contain no useful information.

The target mask is a bit more involved. First, we create a mask for the tokens, as we did for the source mask. Next, we create a "subsequent" mask, trg_sub_mask, using torch.tril. This creates a diagonal matrix where the elements above the diagonal will be zero and the elements below the diagonal will be set to whatever the input tensor is. In this case, the input tensor will be a tensor filled with ones, meaning our trg_sub_mask will look something like this (for a target with 5 tokens):

\begin{matrix} 1 & 0 & 0 & 0 & 0 \\ 1 & 1 & 0 & 0 & 0 \\ 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 \\ 1 & 1 & 1 & 1 & 1 \\ \end{matrix}

This shows what each target token (row) is allowed to look at (column). Our first target token has a mask of [1, 0, 0, 0, 0] which means it can only look at the first target token, whereas the second target token has a mask of [1, 1, 0, 0, 0] which it means it can look at both the first and second target tokens and so on.

The "subsequent" mask is then logically anded with the padding mask, this combines the two masks ensuring both the subsequent tokens and the padding tokens cannot be attended to. For example if the last two tokens were <pad> tokens the final target mask would look like:

\begin{matrix} 1 & 0 & 0 & 0 & 0 \\ 1 & 1 & 0 & 0 & 0 \\ 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 0 & 0 \\ \end{matrix}

These masks are fed in into model along with source and target sentence to get out predicted target output.

Site Note: Introducing some other terminology that we might come across. The need to create a subsequent mask is very common in autoregressive model, where the task is to predict the next token in the sequence (e.g. language model). By introducing this masking, we are making the self attention block casual. Different implementation or library might have different ways of specifying this masking, but the core idea is to prevent the model from "cheating" by copying the tokens that are after the ones it's currently processing.

In [35]:
class Seq2Seq(nn.Module):

    def __init__(self, encoder, decoder, src_pad_idx, trg_pad_idx):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

    def make_src_mask(self, src):
        """
        the padding mask is unsqueezed so it can be correctly broadcasted
        when applying the mask to the attention weights, which is of shape
        [batch size, n heads, seq len, seq len].
        """
        src_pad_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_pad_mask

    def make_trg_mask(self, trg):
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)

        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len))).bool().to(trg.device)
        trg_mask = trg_pad_mask & trg_sub_mask
        return trg_mask

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        encoded_src = self.encoder(src, src_mask)
        decoder_output = self.decoder(trg, encoded_src, trg_mask, src_mask)
        return decoder_output
In [36]:
source_pad_idx = source_tokenizer.token_to_id(pad_token)
target_pad_idx = target_tokenizer.token_to_id(pad_token)

INPUT_DIM = source_tokenizer.get_vocab_size()
OUTPUT_DIM = target_tokenizer.get_vocab_size()
MAX_LEN = 100
HID_DIM = 512
ENC_LAYERS = 6
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1

encoder = Encoder(
    INPUT_DIM, 
    HID_DIM,
    MAX_LEN,
    ENC_DROPOUT, 
    ENC_HEADS, 
    ENC_PF_DIM, 
    ENC_LAYERS
)

decoder = Decoder(
    OUTPUT_DIM, 
    HID_DIM,
    MAX_LEN,
    DEC_DROPOUT,
    DEC_HEADS,
    DEC_PF_DIM,
    DEC_LAYERS
)

model = Seq2Seq(encoder, decoder, source_pad_idx, target_pad_idx).to(device)
model
Out[36]:
Seq2Seq(
  (encoder): Encoder(
    (pos_wise_embedding): PositionWiseEmbedding(
      (tok_embedding): Embedding(5000, 512)
      (pos_embedding): Embedding(100, 512)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (self_attention_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (position_ff_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (self_attention): MultiHeadAttention(
          (key_weight): Linear(in_features=512, out_features=512, bias=True)
          (query_weight): Linear(in_features=512, out_features=512, bias=True)
          (value_weight): Linear(in_features=512, out_features=512, bias=True)
          (linear_weight): Linear(in_features=512, out_features=512, bias=True)
        )
        (position_ff): PositionWiseFeedForward(
          (fc1): Linear(in_features=512, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=512, bias=True)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): Decoder(
    (pos_wise_embedding): PositionWiseEmbedding(
      (tok_embedding): Embedding(5000, 512)
      (pos_embedding): Embedding(100, 512)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): ModuleList(
      (0-2): 3 x DecoderLayer(
        (self_attention_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (encoder_attention_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (position_ff_layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (self_attention): MultiHeadAttention(
          (key_weight): Linear(in_features=512, out_features=512, bias=True)
          (query_weight): Linear(in_features=512, out_features=512, bias=True)
          (value_weight): Linear(in_features=512, out_features=512, bias=True)
          (linear_weight): Linear(in_features=512, out_features=512, bias=True)
        )
        (encoder_attention): MultiHeadAttention(
          (key_weight): Linear(in_features=512, out_features=512, bias=True)
          (query_weight): Linear(in_features=512, out_features=512, bias=True)
          (value_weight): Linear(in_features=512, out_features=512, bias=True)
          (linear_weight): Linear(in_features=512, out_features=512, bias=True)
        )
        (position_ff): PositionWiseFeedForward(
          (fc1): Linear(in_features=512, out_features=512, bias=True)
          (fc2): Linear(in_features=512, out_features=512, bias=True)
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (fc_out): Linear(in_features=512, out_features=5000, bias=True)
  )
)
In [37]:
output = model(source.to(device), target.to(device))
output.shape
Out[37]:
torch.Size([128, 26, 5000])
In [38]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')
The model has 25,144,200 trainable parameters

Model Training

The training loop also requires a bit of explanation.

We want our model to predict the <eos> token but not have it be an input into our model, hence we slice the <eos> token off the end of our target sequence.

\begin{align} \text{trg} &= [sos, x_1, x_2, x_3, eos] \\ \text{trg[:-1]} &= [sos, x_1, x_2, x_3] \end{align}

We then calculate our loss using the original target tensor with the <sos> token sliced off the front, retaining the <eos> token.

\begin{align} \text{output} &= [y_1, y_2, y_3, eos] \\ \text{trg[1:]} &= [x_1, x_2, x_3, eos] \end{align}

All in all, our model receives the target up to the last character (excluding the last), whereas the ground truth will be from the second character onward.

In [39]:
def train(model, iterator, optimizer, criterion, clip):
    
    model.train()
    epoch_loss = 0
    for i, (src, trg) in enumerate(iterator):
        src = src.to(device)
        trg = trg.to(device)

        optimizer.zero_grad()
        output = model(src, trg[:, :-1])
                
        # output = [batch size, trg len - 1, output dim]
        # trg = [batch size, trg len]
        output_dim = output.shape[-1]
        output = output.contiguous().view(-1, output_dim)
        trg = trg[:, 1:].contiguous().view(-1)

        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

Evaluation loop is similar to the training loop, just without the updating the model's parameters.

In [40]:
def evaluate(model, iterator, criterion):
    
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, (src, trg) in enumerate(iterator):
            src = src.to(device)
            trg = trg.to(device)

            output = model(src, trg[:, :-1])
            
            # output = [batch size, trg len - 1, output dim]
            # trg = [batch size, trg len]
            output_dim = output.shape[-1]
            output = output.contiguous().view(-1, output_dim)
            trg = trg[:, 1:].contiguous().view(-1)

            loss = criterion(output, trg)
            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)
In [41]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

While defining our loss function, we also ensure we ignore loss that are calculated over the <pad> tokens.

In [42]:
MODEL_CHECKPOINT = 'transformer.pt'
N_EPOCHS = 10
CLIP = 1
LEARNING_RATE = 0.0001
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index=target_pad_idx)
In [43]:
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss = train(model, dataloader_train, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, dataloader_val, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), MODEL_CHECKPOINT)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')
Epoch: 01 | Time: 0m 16s
	Train Loss: 4.926 | Train PPL: 137.785
	 Val. Loss: 4.057 |  Val. PPL:  57.784
Epoch: 02 | Time: 0m 16s
	Train Loss: 3.808 | Train PPL:  45.049
	 Val. Loss: 3.566 |  Val. PPL:  35.371
Epoch: 03 | Time: 0m 16s
	Train Loss: 3.422 | Train PPL:  30.628
	 Val. Loss: 3.290 |  Val. PPL:  26.830
Epoch: 04 | Time: 0m 16s
	Train Loss: 3.156 | Train PPL:  23.484
	 Val. Loss: 3.078 |  Val. PPL:  21.713
Epoch: 05 | Time: 0m 16s
	Train Loss: 2.945 | Train PPL:  19.018
	 Val. Loss: 2.918 |  Val. PPL:  18.507
Epoch: 06 | Time: 0m 16s
	Train Loss: 2.766 | Train PPL:  15.891
	 Val. Loss: 2.789 |  Val. PPL:  16.257
Epoch: 07 | Time: 0m 17s
	Train Loss: 2.606 | Train PPL:  13.544
	 Val. Loss: 2.681 |  Val. PPL:  14.599
Epoch: 08 | Time: 0m 16s
	Train Loss: 2.463 | Train PPL:  11.738
	 Val. Loss: 2.585 |  Val. PPL:  13.259
Epoch: 09 | Time: 0m 17s
	Train Loss: 2.336 | Train PPL:  10.340
	 Val. Loss: 2.515 |  Val. PPL:  12.368
Epoch: 10 | Time: 0m 17s
	Train Loss: 2.216 | Train PPL:   9.173
	 Val. Loss: 2.448 |  Val. PPL:  11.561

Model Evaluation

In [44]:
model.load_state_dict(torch.load(MODEL_CHECKPOINT))
test_loss = evaluate(model, dataloader_test, criterion)
print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')
| Test Loss: 2.429 | Test PPL:  11.352 |
In [45]:
def predict(source, model, source_tokenizer, target_tokenizer):
    """
    Given the raw token, predict the translation using greedy search.
    This is a naive implementation without batching
    """
    src_indices = [source_init_idx] + source_tokenizer.encode(source).ids + [source_eos_idx]
    src_tensor = torch.LongTensor(src_indices).unsqueeze(0).to(device)
    src_mask = model.make_src_mask(src_tensor)

    # separating out the encoder and decoder allows us to generate the encoded source
    # sentence once and share it throughout the target prediction step
    with torch.no_grad():
        encoded_src = model.encoder(src_tensor, src_mask)

    # greedy search
    # sequentially predict the target sequence starting from the init sentence token
    trg_indices = [target_init_idx]
    for _ in range(max_len):
        trg_tensor = torch.LongTensor(trg_indices).unsqueeze(0).to(device)
        trg_mask = model.make_trg_mask(trg_tensor)

        with torch.no_grad():
            output = model.decoder(trg_tensor, encoded_src, trg_mask, src_mask)

        # add the last predicted token
        pred_token = output.argmax(dim=2)[:, -1].item()
        trg_indices.append(pred_token)
        if pred_token == target_eos_idx:
            break

    return target_tokenizer.decode(trg_indices)
In [46]:
translation = dataset_dict['train'][0]
source_raw = translation[source_lang]
target_raw = translation[target_lang]
print('source: ', source_raw)
print('target: ', target_raw)

predict(source_raw, model, source_tokenizer, target_tokenizer)
source:  Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.
target:  Two young, White males are outside near many bushes.
Out[46]:
'two young men are outside near white.'

Transformer Module

Instead of resorting to our own Transformer encoder and decoder implementation, PyTorch's nn module already comes with a pre-built one. The major difference here is it expects a different shape for the padding and subsequent mask.

In [47]:
class Transformer(nn.Module):
    """
    
    References
    ----------
    https://pytorch.org/docs/master/generated/torch.nn.Transformer.html
    """

    def __init__(
        self,
        encoder_embedding_dim,
        decoder_embedding_dim,
        source_pad_idx,
        target_pad_idx,
        encoder_max_len = 100,
        decoder_max_len = 100,
        model_dim = 512,
        num_head = 8,
        encoder_num_layers = 3,
        decoder_num_layers = 3,
        feedforward_dim = 512,
        dropout = 0.1
    ):
        super().__init__()
        self.source_pad_idx = source_pad_idx
        self.target_pad_idx = target_pad_idx

        self.encoder_embedding = PositionWiseEmbedding(
            encoder_embedding_dim,
            model_dim,
            encoder_max_len,
            dropout
        )
        self.decoder_embedding = PositionWiseEmbedding(
            decoder_embedding_dim,
            model_dim,
            decoder_max_len,
            dropout
        )

        layer_params = {
            'd_model': model_dim,
            'nhead': num_head,
            'dim_feedforward': feedforward_dim,
            'dropout': dropout
        }
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_params),
            encoder_num_layers
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(**layer_params),
            decoder_num_layers
        )
        self.linear = nn.Linear(model_dim, decoder_embedding_dim)

    def forward(self, src_tensor, trg_tensor):
        # enc_src = self.encoder(src, src_mask)
        # decoder_output = self.decoder(trg, enc_src, trg_mask, src_mask)

        # in PyTorch's Transformer Encoder and Decoder implementation, they
        # expect the first dimension to be batch size
        src_encoded = self.encode(src_tensor)
        output = self.decode(trg_tensor, src_encoded)
        return output

    def encode(self, src_tensor):
        src_key_padding_mask = generate_key_padding_mask(src_tensor, self.source_pad_idx)
        src_embedded = self.encoder_embedding(src_tensor).permute(1, 0, 2)
        return self.encoder(src_embedded, src_key_padding_mask=src_key_padding_mask)

    def decode(self, trg_tensor, src_encoded):
        trg_key_padding_mask = generate_key_padding_mask(trg_tensor, self.target_pad_idx)
        trg_mask = generate_subsequent_mask(trg_tensor)
        trg_embedded = self.decoder_embedding(trg_tensor).permute(1, 0, 2)
        decoder_output = self.decoder(
            trg_embedded,
            src_encoded,
            trg_mask,
            tgt_key_padding_mask=trg_key_padding_mask
        ).permute(1, 0, 2)
        return self.linear(decoder_output)

    def predict(self, src_tensor, max_len = 100):
        # separating out the encoder and decoder allows us to generate the encoded source
        # sentence once and share it throughout the target prediction step
        with torch.no_grad():
            src_encoded = self.encode(src_tensor)
  
        # greedy search
        # sequentially predict the target sequence starting from the init sentence token
        trg_indices = [target_init_idx]
        for _ in range(max_len):
            trg_tensor = torch.LongTensor(trg_indices).unsqueeze(0).to(src_tensor.device)
            with torch.no_grad():
                output = self.decode(trg_tensor, src_encoded)

            # add the last predicted token
            pred_token = output.argmax(dim=2)[:, -1].item()
            trg_indices.append(pred_token)
            if pred_token == target_eos_idx:
                break

        return trg_indices


def generate_subsequent_mask(inputs):
    """
    If a BoolTensor is provided, positions with True are not
    allowed to attend while False values will be unchanged
    """
    inputs_len = inputs.shape[1]
    square = torch.ones((inputs_len, inputs_len)).to(inputs.device)
    mask = (torch.tril(square) == 0.0).bool()
    return mask


def generate_key_padding_mask(inputs, pad_idx):
    return (inputs == pad_idx).to(inputs.device)
In [48]:
INPUT_DIM = source_tokenizer.get_vocab_size()
OUTPUT_DIM = target_tokenizer.get_vocab_size()

transformer = Transformer(INPUT_DIM, OUTPUT_DIM, source_pad_idx, target_pad_idx).to(device)

with torch.no_grad():
    output = transformer(source.to(device), target.to(device))

output.shape
Out[48]:
torch.Size([128, 26, 5000])
In [49]:
print(f'The model has {count_parameters(transformer):,} trainable parameters')
The model has 20,410,248 trainable parameters
In [50]:
MODEL_CHECKPOINT = 'transformer.pt'
N_EPOCHS = 10
CLIP = 1
LEARNING_RATE = 0.0005
optimizer = optim.Adam(transformer.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index=target_pad_idx)
In [51]:
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss = train(transformer, dataloader_train, optimizer, criterion, CLIP)
    valid_loss = evaluate(transformer, dataloader_val, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(transformer.state_dict(), MODEL_CHECKPOINT)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')
Epoch: 01 | Time: 0m 13s
	Train Loss: 4.061 | Train PPL:  58.036
	 Val. Loss: 3.313 |  Val. PPL:  27.469
Epoch: 02 | Time: 0m 13s
	Train Loss: 3.010 | Train PPL:  20.289
	 Val. Loss: 2.769 |  Val. PPL:  15.946
Epoch: 03 | Time: 0m 13s
	Train Loss: 2.530 | Train PPL:  12.555
	 Val. Loss: 2.478 |  Val. PPL:  11.921
Epoch: 04 | Time: 0m 13s
	Train Loss: 2.180 | Train PPL:   8.848
	 Val. Loss: 2.319 |  Val. PPL:  10.167
Epoch: 05 | Time: 0m 13s
	Train Loss: 1.922 | Train PPL:   6.834
	 Val. Loss: 2.213 |  Val. PPL:   9.140
Epoch: 06 | Time: 0m 13s
	Train Loss: 1.703 | Train PPL:   5.489
	 Val. Loss: 2.163 |  Val. PPL:   8.697
Epoch: 07 | Time: 0m 13s
	Train Loss: 1.527 | Train PPL:   4.602
	 Val. Loss: 2.148 |  Val. PPL:   8.570
Epoch: 08 | Time: 0m 13s
	Train Loss: 1.361 | Train PPL:   3.898
	 Val. Loss: 2.139 |  Val. PPL:   8.488
Epoch: 09 | Time: 0m 13s
	Train Loss: 1.235 | Train PPL:   3.438
	 Val. Loss: 2.143 |  Val. PPL:   8.526
Epoch: 10 | Time: 0m 13s
	Train Loss: 1.128 | Train PPL:   3.089
	 Val. Loss: 2.192 |  Val. PPL:   8.957
In [52]:
transformer.load_state_dict(torch.load(MODEL_CHECKPOINT))
test_loss = evaluate(transformer, dataloader_test, criterion)
print(f'| Test Loss: {test_loss:.3f} | Test PPL: {math.exp(test_loss):7.3f} |')
| Test Loss: 2.134 | Test PPL:   8.445 |
In [53]:
def transformer_predict(source, model, source_tokenizer, target_tokenizer):
    src_indices = [source_init_idx] + source_tokenizer.encode(source).ids + [source_eos_idx]
    src_tensor = torch.LongTensor(src_indices).unsqueeze(0).to(device)
    trg_indices = model.predict(src_tensor)
    return target_tokenizer.decode(trg_indices)
In [54]:
translation = dataset_dict['train'][0]
source_raw = translation[source_lang]
target_raw = translation[target_lang]
print('source: ', source_raw)
print('target: ', target_raw)

transformer_predict(source_raw, transformer, source_tokenizer, target_tokenizer)
source:  Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.
target:  Two young, White males are outside near many bushes.
Out[54]:
'two young men are outside, one in white, are in the doorway.'

In this notebook, we delved into the implementation of Transformer models. Although originally proposed for solving NLP tasks like machine translation, this module or building block is also gaining popularity in other fields such as computer vision [2].

Reference

  • [1] Jupyter Notebook: Attention is All You Need
  • [2] Jupyter Notebook: Tutorial 6: Transformers and Multi-Head Attention
  • [3] Colab: Simple PyTorch Transformer Example with Greedy Decoding
  • [4] Blog: Transformers from scratch
  • [5] Blog: Making Pytorch Transformer Twice as Fast on Sequence Generation
  • [6] Blog: How Transformers work in deep learning and NLP: an intuitive introduction
  • [7] PyTorch Documentation: Sequence to sequence modeling with nn.Transformer and Torchtext
  • [8] The Annotated Transformer
  • [9] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin - Attention is All you Need (2017)