In [1]:
# code for loading the format for the notebook
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))

from formats import load_style
load_style(plot_style=False)
Out[1]:
In [2]:
os.chdir(path)

# 1. magic for inline plot
# 2. magic to print version
# 3. magic so that the notebook will reload external python modules
# 4. magic to enable retina (high resolution) plots
# https://gist.github.com/minrk/3301035
%matplotlib inline
%load_ext watermark
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# change default style figure and font size
plt.rcParams['figure.figsize'] = 8, 6
plt.rcParams['font.size'] = 12

%watermark -a 'Ethen' -d -t -v -p numpy,pandas,sklearn,matplotlib,xgboost,keras
Using TensorFlow backend.
Ethen 2019-06-14 09:48:45 

CPython 3.6.4
IPython 7.5.0

numpy 1.16.1
pandas 0.24.2
sklearn 0.20.3
matplotlib 3.0.3
xgboost 0.81
keras 2.2.2

Leveraging Word2vec for Text Classification

Many machine learning algorithms require the input features to be represented as a fixed-length feature vector. When it comes to text, some of the most common fixed-length representations are one-hot-encoding-style methods such as bag of words or tf-idf. The advantage of these approaches is their fast execution time, while the main drawback is that they lose the ordering and semantics of the words.
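As a quick illustration of the ordering issue, here is a made-up two-sentence example (not part of the dataset used below), where bag of words maps two sentences with opposite meanings to the same vector:

from sklearn.feature_extraction.text import CountVectorizer

# two sentences with opposite meanings but exactly the same words
docs = ['dog bites man', 'man bites dog']
bag_of_words = CountVectorizer().fit_transform(docs)

# both rows come out identical, i.e. bag of words discards word order
print(bag_of_words.toarray())
# [[1 1 1]
#  [1 1 1]]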

The motivation behind converting text into semantic vectors (such as the ones provided by Word2Vec) is that not only can these types of methods extract semantic relationships (e.g. the word powerful should be closely related to strong, as opposed to an unrelated word like bank), but they should also preserve most of the relevant information about a text while having relatively low dimensionality.

In this notebook, we'll take a look at how a Word2Vec model can also be used as a dimensionality reduction algorithm whose output we feed into a text classifier. A good representation should be able to extract the signal from the noise efficiently, hence improving the performance of the classifier.

Data Preparation

We'll download the text classification data, read it into a pandas dataframe and split it into train and test sets.

In [3]:
import os
from subprocess import call


def download_data(base_dir='.'):
    """download Reuters' text categorization benchmarks from its url."""
    
    train_data = 'r8-train-no-stop.txt'
    test_data = 'r8-test-no-stop.txt'
    concat_data = 'r8-no-stop.txt'
    base_url = 'http://www.cs.umb.edu/~smimarog/textmining/datasets/'
    
    if not os.path.isdir(base_dir):
        os.makedirs(base_dir, exist_ok=True)
        
    dir_prefix_flag = ' --directory-prefix ' + base_dir

    # brew install wget
    # on a mac if you don't have it
    train_data_path = os.path.join(base_dir, train_data)
    if not os.path.isfile(train_data_path):
        call('wget ' + base_url + train_data + dir_prefix_flag, shell=True)
    
    test_data_path = os.path.join(base_dir, test_data)
    if not os.path.isfile(test_data_path):
        call('wget ' + base_url + test_data + dir_prefix_flag, shell=True)

    concat_data_path = os.path.join(base_dir, concat_data)
    if not os.path.isfile(concat_data_path):
        # concatenate train and test files, we'll make our own train-test splits
        # the > piping symbol directs the concatenated file to a new file, it
        # will replace the file if it already exists; on the other hand, the >> symbol
        # will append if it already exists
        train_test_path = os.path.join(base_dir, 'r8-*-no-stop.txt')
        call('cat {} > {}'.format(train_test_path, concat_data_path), shell=True)

    return concat_data_path
In [4]:
base_dir = 'data'
data_path = download_data(base_dir)
data_path
Out[4]:
'data/r8-no-stop.txt'
In [5]:
def load_data(data_path):
    texts, labels = [], []
    with open(data_path) as f:
        for line in f:
            label, text = line.split('\t')
            # texts are already tokenized, just split on space
            # in a real use-case we would put more effort in preprocessing
            texts.append(text.split())
            labels.append(label)
            
    return pd.DataFrame({'texts': texts, 'labels': labels})
In [6]:
data = load_data(data_path)
data['labels'] = data['labels'].astype('category')
print('dimension: ', data.shape)
data.head()
dimension:  (7674, 2)
Out[6]:
texts labels
0 [asian, exporters, fear, damage, japan, rift, ... trade
1 [china, daily, vermin, eat, pct, grain, stocks... grain
2 [australian, foreign, ship, ban, ends, nsw, po... ship
3 [sumitomo, bank, aims, quick, recovery, merger... acq
4 [amatil, proposes, two, for, bonus, share, iss... earn
In [7]:
label_mapping = data['labels'].cat.categories
data['labels'] = data['labels'].cat.codes
X = data['texts']
y = data['labels']
In [8]:
test_size = 0.1
random_state = 1234

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=test_size, random_state=random_state, stratify=y)

# val_size = 0.1
# X_train, X_val, y_train, y_val = train_test_split(
#     X_train, y_train, test_size=val_size, random_state=random_state, stratify=y_train)

Gensim Implementation

After feeding the Word2Vec algorithm with our corpus, it will learn a vector representation for each word. This by itself, however, is still not enough to be used as features for text classification, as each record in our data is a document, not a word.

To extend these word vectors and generate document-level vectors, we'll take the naive approach and use the average of all the word vectors in the document (we could also leverage tf-idf to generate a weighted-average version, but that is not done here). The Word2Vec algorithm is wrapped inside a sklearn-compatible transformer which can be used almost the same way as CountVectorizer or TfidfVectorizer from sklearn.feature_extraction.text. Almost, because sklearn vectorizers can also do their own tokenization, a feature we won't be using anyway since our corpus is already tokenized.

In the next few code chunks, we will build a pipeline that transforms the text into low-dimensional vectors via averaged word vectors and use it to fit a boosted tree model, then report the performance on the training and test sets.

The transformers folder that contains the implementation is at the following link.
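For reference, here is a minimal sketch of what such an averaging transformer could look like. This is an illustration of the idea only (assuming a gensim 3.x style API, matching the parameters used below), not the exact code living in the transformers folder:

import numpy as np
from gensim.models import Word2Vec
from sklearn.base import BaseEstimator, TransformerMixin


class AverageWord2VecVectorizer(BaseEstimator, TransformerMixin):
    """Illustrative only: fit a gensim Word2Vec model, then represent each
    document as the average of its in-vocabulary word vectors."""

    def __init__(self, size=50, min_count=3, sg=1, alpha=0.025, iter=10):
        self.size = size
        self.min_count = min_count
        self.sg = sg
        self.alpha = alpha
        self.iter = iter

    def fit(self, X, y=None):
        # X is an iterable of tokenized documents (lists of words)
        self.model_ = Word2Vec(X, size=self.size, min_count=self.min_count,
                               sg=self.sg, alpha=self.alpha, iter=self.iter)
        return self

    def transform(self, X):
        docs = []
        for tokens in X:
            vectors = [self.model_.wv[w] for w in tokens if w in self.model_.wv]
            if vectors:
                docs.append(np.mean(vectors, axis=0))
            else:
                # fall back to a zero vector when no token is in the vocabulary
                docs.append(np.zeros(self.size))
        return np.vstack(docs)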

In [9]:
from xgboost import XGBClassifier
from sklearn.pipeline import Pipeline
from transformers import GensimWord2VecVectorizer

gensim_word2vec_tr = GensimWord2VecVectorizer(size=50, min_count=3, sg=1, alpha=0.025, iter=10)
xgb = XGBClassifier(learning_rate=0.01, n_estimators=100, n_jobs=-1)
w2v_xgb = Pipeline([
    ('w2v', gensim_word2vec_tr), 
    ('xgb', xgb)
])
w2v_xgb
Out[9]:
Pipeline(memory=None,
     steps=[('w2v', GensimWord2VecVectorizer(alpha=0.025, batch_words=10000, callbacks=(),
             cbow_mean=1, compute_loss=False,
             hashfxn=<built-in function hash>, hs=0, iter=10,
             max_final_vocab=None, max_vocab_size=None, min_alpha=0.0001,
             min_count=3, negati...tate=0, reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
       seed=None, silent=True, subsample=1))])
In [10]:
import time

start = time.time()
w2v_xgb.fit(X_train, y_train)
elapse = time.time() - start
print('elapsed: ', elapse)
w2v_xgb
elapsed:  11.784907102584839
Out[10]:
Pipeline(memory=None,
     steps=[('w2v', GensimWord2VecVectorizer(alpha=0.025, batch_words=10000, callbacks=(),
             cbow_mean=1, compute_loss=False,
             hashfxn=<built-in function hash>, hs=0, iter=10,
             max_final_vocab=None, max_vocab_size=None, min_alpha=0.0001,
             min_count=3, negati...
       reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
       silent=True, subsample=1))])
In [11]:
from sklearn.metrics import accuracy_score, confusion_matrix

y_train_pred = w2v_xgb.predict(X_train)
print('Training set accuracy %s' % accuracy_score(y_train, y_train_pred))
confusion_matrix(y_train, y_train_pred)
Training set accuracy 0.9485954242687518
Out[11]:
array([[2026,    2,   29,    0,    0,    3,    0,    3],
       [  12,  313,    3,    0,    0,    0,    7,    1],
       [ 148,    0, 3381,    0,    0,    1,    0,    0],
       [   3,    1,    6,   23,    1,    0,    4,    8],
       [   5,    0,    3,    0,  205,   28,    1,    2],
       [   5,    1,    7,    0,    9,  238,    0,    4],
       [  16,    6,    5,    0,    2,    1,   96,    4],
       [   3,    1,    7,    0,    7,    6,    0,  269]])
In [12]:
y_test_pred = w2v_xgb.predict(X_test)
print('Test set accuracy %s' % accuracy_score(y_test, y_test_pred))
confusion_matrix(y_test, y_test_pred)
Test set accuracy 0.9244791666666666
Out[12]:
array([[224,   0,   3,   0,   0,   1,   0,   1],
       [  2,  34,   0,   0,   0,   0,   1,   1],
       [ 17,   0, 375,   0,   0,   0,   1,   0],
       [  0,   0,   0,   1,   0,   0,   1,   3],
       [  1,   1,   0,   0,  19,   5,   0,   1],
       [  2,   0,   1,   0,   3,  21,   0,   2],
       [  1,   1,   2,   0,   0,   0,   8,   2],
       [  0,   2,   0,   0,   0,   3,   0,  28]])

We can extract the Word2vec part of the pipeline and do a quick sanity check on whether the word vectors that were learned make any sense.

In [13]:
vocab_size = len(w2v_xgb.named_steps['w2v'].model_.wv.index2word)
print('vocabulary size:', vocab_size)
w2v_xgb.named_steps['w2v'].model_.wv.most_similar(positive=['stock'])
vocabulary size: 9846
Out[13]:
[('shares', 0.8288736939430237),
 ('common', 0.8102667927742004),
 ('dilutive', 0.7489669919013977),
 ('effected', 0.7259572744369507),
 ('warrants', 0.718142032623291),
 ('lazo', 0.7158320546150208),
 ('pubco', 0.7046886682510376),
 ('dealings', 0.7039074897766113),
 ('fractional', 0.7031500339508057),
 ('spartech', 0.7023240327835083)]

Keras Implementation

We'll also show how we can use a generic deep learning framework to implement the Word2Vec part of the pipeline. There are many variants of Word2Vec; here, we'll only be implementing skip-gram with negative sampling.

The flow would look like the following:

An (integer) input of a target word and a real or negative context word. This is essentially the skip-gram part, where any word within the context window of the target word is a real context word, and we randomly draw from the rest of the vocabulary to serve as negative context words.

An embedding layer lookup (i.e. looking up the integer index of the word in the embedding matrix to get the word vector).

A dot product operation. As the network trains, words which are similar should end up having similar embedding vectors. The most popular way of measuring similarity between two vectors $A$ and $B$ is the cosine similarity.

\begin{align} \text{similarity} = \cos(\theta) = \frac{\textbf{A} \cdot \textbf{B}}{\parallel \textbf{A} \parallel_2 \parallel \textbf{B} \parallel_2} \end{align}

The denominator of this measure acts to normalize the result; the real similarity operation is in the numerator: the dot product between vectors $A$ and $B$ (a small numerical check of this formula appears right after this outline).

Followed by a sigmoid output layer. Our network is a binary classifier since it's distinguishing words from the same context versus those that aren't.
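As promised above, a small numerical check of the cosine similarity formula, using made-up vectors purely to illustrate the computation:

import numpy as np

def cosine_similarity(a, b):
    # dot product of the two vectors divided by the product of their L2 norms
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 4.0, 6.0])
print(cosine_similarity(a, b))  # 1.0, the two vectors point in the same direction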

In [14]:
# the keras model/graph would look something like this:
from keras import layers, optimizers, Model

# adjustable parameter that control the dimension of the word vectors
embed_size = 100

input_center = layers.Input((1,))
input_context = layers.Input((1,))

embedding = layers.Embedding(vocab_size, embed_size, input_length=1, name='embed_in')
center = embedding(input_center)  # shape [seq_len, # features (1), embed_size]
context = embedding(input_context)

center = layers.Reshape((embed_size,))(center)
context = layers.Reshape((embed_size,))(context)

dot_product = layers.dot([center, context], axes=1)
output = layers.Dense(1, activation='sigmoid')(dot_product)
model = Model(inputs=[input_center, input_context], outputs=output)
model.compile(loss='binary_crossentropy', optimizer=optimizers.RMSprop(lr=0.01))
model.summary()
WARNING:tensorflow:From /Users/mingyuliu/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 1)            0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 1)            0                                            
__________________________________________________________________________________________________
embed_in (Embedding)            (None, 1, 100)       984600      input_1[0][0]                    
                                                                 input_2[0][0]                    
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 100)          0           embed_in[0][0]                   
__________________________________________________________________________________________________
reshape_2 (Reshape)             (None, 100)          0           embed_in[1][0]                   
__________________________________________________________________________________________________
dot_1 (Dot)                     (None, 1)            0           reshape_1[0][0]                  
                                                                 reshape_2[0][0]                  
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1)            2           dot_1[0][0]                      
==================================================================================================
Total params: 984,602
Trainable params: 984,602
Non-trainable params: 0
__________________________________________________________________________________________________
In [15]:
# then we can feed in skip-gram pairs and their labels (whether each word pair
# comes from within the context window or not)
batch_center = [2354, 2354, 2354, 69, 69]
batch_context = [4288, 203, 69, 2535, 815]
batch_label = [0, 1, 1, 0, 1]
model.train_on_batch([batch_center, batch_context], batch_label)
WARNING:tensorflow:From /Users/mingyuliu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
Out[15]:
0.6951684
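The batch above was hardcoded purely for illustration. In practice we would generate the (center, context, label) triples from the corpus itself, for example with keras.preprocessing.sequence.skipgrams. A rough sketch, assuming the documents have already been mapped to lists of integer word indices (the ids below are placeholders) and reusing the model and vocab_size defined above:

from keras.preprocessing.sequence import skipgrams, make_sampling_table

encoded_doc = [2354, 69, 203, 2535, 815]  # hypothetical word indices for one document

# down-sample very frequent words, mirroring the original word2vec heuristic
sampling_table = make_sampling_table(vocab_size)
pairs, labels = skipgrams(encoded_doc, vocabulary_size=vocab_size,
                          window_size=5, negative_samples=2,
                          sampling_table=sampling_table)

# the sampling table can drop words, so a short document may yield no pairs
if pairs:
    batch_center = [center for center, _ in pairs]
    batch_context = [context for _, context in pairs]
    model.train_on_batch([batch_center, batch_context], labels)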

The transformers folder that contains the implementation is at the following link.

In [16]:
from transformers import KerasWord2VecVectorizer

keras_word2vec_tr = KerasWord2VecVectorizer(embed_size=50, min_count=3, epochs=5000,
                                            negative_samples=2)
keras_word2vec_tr
Out[16]:
KerasWord2VecVectorizer(batch_size=64, embed_size=50, epochs=5000,
            learning_rate=0.05, min_count=3, negative_samples=2,
            sort_vocab=True, use_sampling_table=True, window_size=5)
In [17]:
keras_w2v_xgb = Pipeline([
    ('w2v', keras_word2vec_tr), 
    ('xgb', xgb)
])

keras_w2v_xgb.fit(X_train, y_train)
100%|██████████| 5000/5000 [02:49<00:00, 29.45it/s]
Out[17]:
Pipeline(memory=None,
     steps=[('w2v', KerasWord2VecVectorizer(batch_size=64, embed_size=50, epochs=5000,
            learning_rate=0.05, min_count=3, negative_samples=2,
            sort_vocab=True, use_sampling_table=True, window_size=5)), ('xgb', XGBClassifier(base_score=0.5, booster='gbtree', colsample_bylevel=1,
     ...
       reg_alpha=0, reg_lambda=1, scale_pos_weight=1, seed=None,
       silent=True, subsample=1))])
In [18]:
y_train_pred = keras_w2v_xgb.predict(X_train)
print('Training set accuracy %s' % accuracy_score(y_train, y_train_pred))
confusion_matrix(y_train, y_train_pred)
Training set accuracy 0.9093541847668694
Out[18]:
array([[1966,    4,   69,    0,    6,    6,    2,   10],
       [  57,  255,   14,    0,    1,    1,    5,    3],
       [ 143,    4, 3377,    0,    2,    0,    2,    2],
       [  10,    2,    3,   16,    0,    2,    8,    5],
       [   5,    2,    5,    0,  209,   16,    0,    7],
       [  19,    6,   13,    0,   26,  177,    0,   23],
       [  43,    6,    7,    4,    0,    1,   65,    4],
       [  10,   14,   39,    0,    5,    7,    3,  215]])
In [19]:
y_test_pred = keras_w2v_xgb.predict(X_test)
print('Test set accuracy %s' % accuracy_score(y_test, y_test_pred))
confusion_matrix(y_test, y_test_pred)
Test set accuracy 0.8958333333333334
Out[19]:
array([[218,   1,   6,   1,   0,   0,   1,   2],
       [  7,  28,   1,   0,   0,   0,   1,   1],
       [ 20,   1, 371,   0,   1,   0,   0,   0],
       [  1,   1,   1,   0,   0,   0,   1,   1],
       [  0,   0,   4,   0,  18,   4,   0,   1],
       [  2,   0,   1,   0,   4,  20,   0,   2],
       [  2,   2,   1,   0,   0,   1,   7,   1],
       [  2,   0,   2,   0,   2,   1,   0,  26]])
In [20]:
print('vocabulary size:', keras_w2v_xgb.named_steps['w2v'].vocab_size_)
keras_w2v_xgb.named_steps['w2v'].most_similar(positive=['stock'])
vocabulary size: 9847
Out[20]:
[('common', 0.7450751),
 ('split', 0.72033733),
 ('annual', 0.7107709),
 ('acquistion', 0.7084387),
 ('remaining', 0.6948875),
 ('total', 0.69227046),
 ('shareholder', 0.68726707),
 ('shareholders', 0.6678084),
 ('associates', 0.64986753),
 ('board', 0.6477556)]

Benchmarks

We'll compare the word2vec + xgboost approach with tfidf + logistic regression. The latter approach is known for its interpretability and fast training time, hence serves as a strong baseline.

Note that for sklearn's tfidf, we don't use the default analyzer 'word', which expects the input to be a single string that it will split into individual words. Our texts are already tokenized, i.e. already lists of words, so we pass an identity function as the analyzer.

In [21]:
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer
 
tfidf = TfidfVectorizer(stop_words='english', analyzer=lambda x: x)
logistic = LogisticRegression(solver='liblinear', multi_class='auto')

tfidf_logistic = Pipeline([
    ('tfidf', tfidf), 
    ('logistic', logistic)
])
In [22]:
from scipy.stats import randint, uniform

w2v_params = {'w2v__size': [100, 150, 200]}
tfidf_params = {'tfidf__ngram_range': [(1, 1), (1, 2)]}
logistic_params = {'logistic__C': [0.5, 1.0, 1.5]}
xgb_params = {'xgb__max_depth': randint(low=3, high=12),
              'xgb__colsample_bytree': uniform(loc=0.8, scale=0.2),
              'xgb__subsample': uniform(loc=0.8, scale=0.2)}

tfidf_logistic_params = {**tfidf_params, **logistic_params}
w2v_xgb_params = {**w2v_params, **xgb_params}
In [23]:
from sklearn.model_selection import RandomizedSearchCV

cv = 3
n_iter = 3
random_state = 1234
scoring = 'accuracy'

all_models = [
    ('w2v_xgb', w2v_xgb, w2v_xgb_params),
    ('tfidf_logistic', tfidf_logistic, tfidf_logistic_params)
]

all_models_info = []
for name, model, params in all_models:
    print('training:', name)
    model_tuned = RandomizedSearchCV(
        estimator=model,
        param_distributions=params,
        cv=cv,
        n_iter=n_iter,
        n_jobs=-1,
        verbose=1,
        scoring=scoring,
        random_state=random_state,
        return_train_score=False
    ).fit(X_train, y_train)
    
    y_test_pred = model_tuned.predict(X_test)
    test_score = accuracy_score(y_test, y_test_pred)
    info = name, model_tuned.best_score_, test_score, model_tuned
    all_models_info.append(info)

columns = ['model_name', 'train_score', 'test_score', 'estimator']
results = pd.DataFrame(all_models_info, columns=columns)
results = (results
           .sort_values('test_score', ascending=False)
           .reset_index(drop=True))
results
training: w2v_xgb
Fitting 3 folds for each of 3 candidates, totalling 9 fits
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   4 out of   9 | elapsed:  2.0min remaining:  2.4min
[Parallel(n_jobs=-1)]: Done   9 out of   9 | elapsed:  2.3min finished
training: tfidf_logistic
Fitting 3 folds for each of 3 candidates, totalling 9 fits
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.
[Parallel(n_jobs=-1)]: Done   4 out of   9 | elapsed:    1.4s remaining:    1.7s
[Parallel(n_jobs=-1)]: Done   9 out of   9 | elapsed:    2.0s finished
Out[23]:
model_name train_score test_score estimator
0 tfidf_logistic 0.949175 0.962240 RandomizedSearchCV(cv=3, error_score='raise-de...
1 w2v_xgb 0.954532 0.958333 RandomizedSearchCV(cv=3, error_score='raise-de...

Note that different runs may report different performance numbers. And as our dataset changes, the approach that worked best on one dataset might no longer be the best on another. Especially since the dataset we're working with here isn't very big, training an embedding from scratch will most likely not reach its full potential.

There are many other text classification techniques in the deep learning realm that we haven't yet explored; we'll leave those for another day.

Reference