In [1]:
# code for loading the format for the notebook
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))

from formats import load_style
load_style(css_style='custom2.css', plot_style=False)
Out[1]:
In [2]:
os.chdir(path)

# 1. magic for inline plot
# 2. magic to print version
# 3. magic so that the notebook will reload external python modules
# 4. magic to enable retina (high resolution) plots
# https://gist.github.com/minrk/3301035
%matplotlib inline
%load_ext watermark
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'

import time
import nmslib  # pip install nmslib>=1.7.3.2 pybind11>=2.2.3
import zipfile
import requests
import numpy as np
import matplotlib.pyplot as plt
from tqdm import trange
from joblib import dump, load
from sklearn.preprocessing import normalize
from sklearn.neighbors import NearestNeighbors
from sklearn.model_selection import train_test_split

# change default style figure and font size
plt.rcParams['figure.figsize'] = 8, 6
plt.rcParams['font.size'] = 12

%watermark -a 'Ethen' -d -t -v -p numpy,sklearn,matplotlib,tqdm,nmslib
Ethen 2018-08-21 18:38:17 

CPython 3.6.4
IPython 6.4.0

numpy 1.14.1
sklearn 0.19.1
matplotlib 2.2.2
tqdm 4.24.0
nmslib 1.7.3.4

Approximate Nearest Neighbor Search

Approximate nearest neighbor (ANN) search is useful when we have a large dataset containing hundreds of thousands, millions, or even billions of data points, and for a given data point we wish to find its nearest neighbors. There are many use cases for this type of method; the one we'll be focusing on here is finding similar vector representations, so think of algorithms such as matrix factorization or word2vec that compress our original data into embeddings, or so-called latent factors. Throughout the notebook, the notion of similarity refers to two vectors' cosine distance.
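As a quick refresher, here is a minimal sketch (with made-up vectors, purely illustrative) of what cosine distance measures: one minus the cosine of the angle between the two vectors.

import numpy as np

def cosine_distance(a, b):
    # cosine distance = 1 - cosine similarity
    return 1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

a = np.array([1.0, 0.0])
b = np.array([1.0, 1.0])
print(cosine_distance(a, b))  # ~0.293, i.e. 1 - cos(45 degrees)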

There are already many open-source implementations we can try out, but the question is always: which one is better? The following github repo contains thorough benchmarks of various open-source implementations. Github: Benchmarking nearest neighbors.

The goal of this notebook is to show how to run a quicker benchmark ourselves without all the complexity. The repo listed above benchmarks multiple algorithms on multiple datasets using multiple hyperparameters, which can take a really long time. We will pick one open-source implementation that has been identified as a solid choice and walk through the process step by step using a single dataset.

Setting Up the Data

The first step is to get our hands on some data and split it into a training and test set. Here we'll be using the GloVe vector representations trained on a Twitter dataset.

In [3]:
def download(url, filename):
    # stream the download in chunks instead of loading the
    # entire (large, >1 GB) file into memory at once
    with open(filename, 'wb') as file:
        response = requests.get(url, stream=True)
        for chunk in response.iter_content(chunk_size=1024 * 1024):
            file.write(chunk)

# we'll download the data to DATA_DIR location
DATA_DIR = './datasets/'
URL = 'http://nlp.stanford.edu/data/glove.twitter.27B.zip'
filename = os.path.join(DATA_DIR, 'glove.twitter.27B.zip')
if not os.path.exists(DATA_DIR):
    os.makedirs(DATA_DIR)

if not os.path.exists(filename):
    download(URL, filename)
In [4]:
def get_train_test_data(filename, dimension=25, test_size=0.2, random_state=1234):
    """
    dimension : int, {25, 50, 100, 200}, default 25
        The dataset contains embeddings of different size.
    """
    with zipfile.ZipFile(filename) as f:
        X = []
        zip_filename = 'glove.twitter.27B.{}d.txt'.format(dimension)
        for line in f.open(zip_filename):
            # skip the first field (the word token) and keep only the vector values
            vector = np.array([float(x) for x in line.strip().split()[1:]])
            X.append(vector)

        X_train, X_test = train_test_split(
            np.array(X), test_size=test_size, random_state=random_state)

    # we can downsample for experimentation purposes
    # X_train = X_train[:50000]
    # X_test = X_test[:10000]
    return X_train, X_test


X_train, X_test = get_train_test_data(filename)
print('training data shape: ', X_train.shape)
print('testing data shape: ', X_test.shape)
training data shape:  (954811, 25)
testing data shape:  (238703, 25)

Benchmarking an approximate nearest neighbor method involves looking at how much faster it is compared to an exact nearest neighbor method, and how much precision/recall we lose for the speed gained. To measure this, we first need to run an exact nearest neighbor method, record how long it takes, and store its results as the ground truth. e.g. if our exact nearest neighbor method thinks that for data point 1, its top 3 nearest neighbors excluding itself are [2, 4, 1], and our approximate nearest neighbor method returns [2, 1, 5], then our precision/recall (depending on which way we're looking at it) would be 66%, since 2 and 1 are both in the ground truth set whereas 5 is not.
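To make the metric concrete, here is a minimal sketch of that toy example (the indices are hypothetical, mirroring the description above):

correct_indices = {2, 4, 1}  # the exact method's top 3 (ground truth)
found_indices = [2, 1, 5]    # the approximate method's top 3
precision = len(correct_indices.intersection(found_indices)) / len(found_indices)
print(precision)  # 0.666..., i.e. roughly 66%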

In [5]:
class BruteForce:
    """
    Brute force way of computing cosine distance; this is
    more to clarify what we're trying to accomplish. Don't
    actually use it in practice, as it will take extremely long.
    """

    def __init__(self):
        pass

    def fit(self, X):
        lens = (X ** 2).sum(axis=-1)
        index = X / np.sqrt(lens)[:, np.newaxis]
        self.index_ = np.ascontiguousarray(index, dtype=np.float32)
        return self

    def query(self, vector, topn):
        """Find indices of most similar vectors for a given query vector."""

        # argmax_a dot(a, b) / |a||b| = argmin_a -dot(a, b)
        dists = -np.dot(self.index_, vector)
        indices = np.argpartition(dists, topn)[:topn]
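        # note: argpartition only guarantees that the indices of the
        # topn smallest distances occupy the first topn positions, in
        # arbitrary order, hence the final sort of this small subset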
        return sorted(indices, key=lambda index: dists[index])


class KDTree:

    def __init__(self, topn=10, n_jobs=-1):
        self.topn = topn
        self.n_jobs = n_jobs

    def fit(self, X):

        # for unit-normalized vectors, squared euclidean distance is a
        # monotone function of cosine distance (||a - b||^2 = 2 * (1 - cos)),
        # thus we normalize the item vectors and use the euclidean metric,
        # which lets us use the more efficient kd-tree for nearest neighbor search
        X_normed = normalize(X)
        index = NearestNeighbors(
            n_neighbors=self.topn, metric='euclidean', n_jobs=self.n_jobs)
        index.fit(X_normed)
        self.index_ = index
        return self

    def query_batch(self, X):
        X_normed = normalize(X)
        _, indices = self.index_.kneighbors(X_normed)
        return indices

    def query(self, vector):
        vector_normed = normalize(vector.reshape(1, -1))
        _, indices = self.index_.kneighbors(vector_normed)
        return indices.ravel()
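As a quick sanity check of the comment in KDTree.fit, here is a side sketch (with random vectors, reusing the normalize imported above, not part of the benchmark) verifying that for unit-normalized vectors, squared euclidean distance is a monotone function of cosine distance:

rng = np.random.RandomState(1234)
a, b = normalize(rng.randn(2, 25))
cosine_dist = 1 - a.dot(b)
squared_euclidean = ((a - b) ** 2).sum()
print(np.allclose(squared_euclidean, 2 * cosine_dist))  # True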
In [6]:
def get_ground_truth(X_train, X_test, kdtree_params):
    """
    Compute the ground truth, or so-called gold standard: we record
    the time to build the index using the training set and the time
    to query the nearest neighbors for all the data points in the
    test set. The ground_truth returned will be of type
    list[(ndarray, ndarray)], where the first ndarray is the query
    vector and the second ndarray is the indices of its nearest neighbors.
    """
    start = time.time()
    kdtree = KDTree(**kdtree_params)
    kdtree.fit(X_train)
    build_time = time.time() - start

    start = time.time()
    indices = kdtree.query_batch(X_test)
    query_time = time.time() - start

    ground_truth = [(vector, index) for vector, index in zip(X_test, indices)]
    return build_time, query_time, ground_truth
In [7]:
# we'll compute the ground truth for the first time and
# store it on disk to prevent computing it over and over again
MODEL_DIR = 'model'
if not os.path.exists(MODEL_DIR):
    os.makedirs(MODEL_DIR)

ground_truth_filename = 'ground_truth.pkl'
ground_truth_filepath = os.path.join(MODEL_DIR, ground_truth_filename)
print('ground truth filepath: ', ground_truth_filepath)

if os.path.exists(ground_truth_filepath):
    ground_truth = load(ground_truth_filepath)
else:
    # using a setting of kdtree_params = {'topn': 10, 'n_jobs': -1},
    # it took at least 1 hour to finish on an 8-core machine
    kdtree_params = {'topn': 10, 'n_jobs': -1}
    build_time, query_time, ground_truth = get_ground_truth(X_train, X_test, kdtree_params)
    print('build time: ', build_time)
    print('query time: ', query_time)
    dump(ground_truth, ground_truth_filepath)

ground_truth[0]
ground truth filepath:  model/ground_truth.pkl
build time:  5.02460503578186
query time:  5105.871987104416
Out[7]:
(array([ 0.84227,  0.19005,  1.5346 ,  0.88995, -1.6548 , -0.60046,
        -1.3206 , -1.5521 , -0.30763, -0.56361,  1.5054 ,  3.2881 ,
         1.7582 , -0.63313, -0.48781,  2.0016 , -2.5334 ,  1.0601 ,
        -0.19666, -0.38252,  0.65653,  0.89475,  2.7882 ,  2.4109 ,
        -0.72981]),
 array([213945, 566700, 232533, 673941,  79801, 932371,  59183, 318977,
        649659, 871934]))

Benchmarking ANN Methods

The library that we'll be leveraging here is nmslib, specifically the HNSW (Hierarchical Navigable Small World) algorithm, a graph-based approximate nearest neighbor search method. We will only be using the library and will not be introducing the details of the algorithm in this notebook.

In [8]:
class Hnsw:

    def __init__(self, space='cosinesimil', index_params=None,
                 query_params=None, print_progress=True):
        self.space = space
        self.index_params = index_params
        self.query_params = query_params
        self.print_progress = print_progress

    def fit(self, X):
        index_params = self.index_params
        if index_params is None:
            index_params = {'M': 16, 'post': 0, 'efConstruction': 400}

        query_params = self.query_params
        if query_params is None:
            query_params = {'ef': 90}

        # this is the actual nmslib part; hopefully the syntax is fairly
        # readable, and the documentation has a more detailed introduction:
        # https://nmslib.github.io/nmslib/quickstart.html
        index = nmslib.init(space=self.space, method='hnsw')
        index.addDataPointBatch(X)
        index.createIndex(index_params, print_progress=self.print_progress)
        index.setQueryTimeParams(query_params)

        self.index_ = index
        self.index_params_ = index_params
        self.query_params_ = query_params
        return self

    def query(self, vector, topn):
        # the knnQuery returns indices and corresponding distance
        # we will throw the distance away for now
        indices, _ = self.index_.knnQuery(vector, k=topn)
        return indices

Like a lot of machine learning algorithms, there are hyperparameters that we can tune. We will pick an arbitrary setting for now and look at the influence of each hyperparameter in a later section.

In [9]:
index_params = {'M': 5, 'post': 0, 'efConstruction': 100}

start = time.time()
hnsw = Hnsw(index_params=index_params)
hnsw.fit(X_train)
build_time = time.time() - start
build_time
Out[9]:
42.73225116729736

We'll first use the first element from the ground truth to showcase what we'll be doing before scaling it to all the data points.

In [10]:
topn = 10

query_vector, correct_indices = ground_truth[0]
start = time.time()

# use the query_vector to find its corresponding
# approximate nearest neighbors
found_indices = hnsw.query(query_vector, topn)

query_time = time.time() - start
print('query time:', query_time)

print('correct indices: ', correct_indices)
print('found indices: ', found_indices)
query time: 0.0002560615539550781
correct indices:  [213945 566700 232533 673941  79801 932371  59183 318977 649659 871934]
found indices:  [213945 566700 232533 673941  79801 318977 871934 221617 107727 705332]
In [11]:
# compute the proportion of data points that overlap between the
# two sets
precision = len(set(found_indices).intersection(correct_indices)) / topn
precision
Out[11]:
0.7
In [12]:
def run_algo(X_train, X_test, topn, ground_truth, algo_type='hnsw', algo_params=None):
    """
    We can extend this benchmark to multiple algorithms or hyperparameter
    settings by adding more algo_type options. The algo_params is a
    dictionary that gets passed to the algorithm's __init__ method.
    Here, only one method is included.
    """

    if algo_type == 'hnsw':
        algo = Hnsw()
        if algo_params is not None:
            algo = Hnsw(**algo_params)

    start = time.time()
    algo.fit(X_train)
    build_time = time.time() - start

    total_correct = 0
    total_query_time = 0.0
    n_queries = len(ground_truth)
    for i in trange(n_queries):
        query_vector, correct_indices = ground_truth[i]

        start = time.time()
        found_indices = algo.query(query_vector, topn)
        query_time = time.time() - start
        total_query_time += query_time

        n_correct = len(set(found_indices).intersection(correct_indices))
        total_correct += n_correct

    avg_query_time = total_query_time / n_queries
    avg_precision = total_correct / (n_queries * topn)
    return build_time, avg_query_time, avg_precision

The next few code chunks experiment with different parameters to see which ones work better for this use case.

As recommended by the package's author, the most influential parameters are M and efConstruction.

  • efConstruction: Increasing this value improves the quality of the constructed graph and leads to higher search accuracy, at the cost of longer indexing time. The same idea applies to the ef or efSearch parameter that we can pass to query_params (see the sketch after this section). A reasonable range for this parameter is 100-2000.
  • M: This parameter controls the maximum number of neighbors for each layer. Increasing its value (to a certain degree) leads to better recall and shorter retrieval times, at the expense of longer indexing time. A reasonable range for this parameter is 5-100.

Other parameters include indexThreadQty (we can explicitly set the number of threads) and post. The post parameter controls the amount of post-processing done to the graph: the default is 0, which means no post-processing, and additional options are 1 and 2 (2 means more post-processing).
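As a quick illustration of the query-time knob, here is a minimal sketch (reusing the hnsw index, query_vector, correct_indices and topn from the earlier cells, and the same setQueryTimeParams call shown in the Hnsw class) of sweeping ef on an already-built index to trade precision for query speed:

for ef in (10, 50, 100, 200):
    # larger ef explores more of the graph at query time:
    # slower queries, but typically higher precision
    hnsw.index_.setQueryTimeParams({'ef': ef})
    start = time.time()
    found_indices = hnsw.query(query_vector, topn)
    query_time = time.time() - start
    precision = len(set(found_indices).intersection(correct_indices)) / topn
    print('ef: {}, query time: {:.6f}, precision: {}'.format(ef, query_time, precision))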

In [13]:
# we will be running four combinations, higher/lower
# efConstruction/M parameters and comparing the performance
algo_type = 'hnsw'
algo_params = {
    'index_params': {'M': 16, 'post': 0, 'efConstruction': 100}
}

build_time1, avg_query_time1, avg_precision1 = run_algo(
    X_train, X_test, topn, ground_truth, algo_type, algo_params)

print('build time: ', build_time1)
print('average search time: ', avg_query_time1)
print('average precision: ', avg_precision1)
100%|██████████| 238703/238703 [00:42<00:00, 5642.16it/s]
build time:  96.66552662849426
average search time:  0.00014930271766456666
average precision:  0.971047284701072

In [14]:
algo_params = {
    'index_params': {'M': 16, 'post': 0, 'efConstruction': 400}
}

build_time2, avg_query_time2, avg_precision2 = run_algo(
    X_train, X_test, topn, ground_truth, algo_type, algo_params)

print('build time: ', build_time2)
print('average search time: ', avg_query_time2)
print('average precision: ', avg_precision2)
100%|██████████| 238703/238703 [00:45<00:00, 5257.06it/s]
build time:  312.3543019294739
average search time:  0.0001598479527504984
average precision:  0.9770271006229498

In [15]:
algo_params = {
    'index_params': {'M': 5, 'post': 0, 'efConstruction': 100}
}

build_time3, avg_query_time3, avg_precision3 = run_algo(
    X_train, X_test, topn, ground_truth, algo_type, algo_params)

print('build time: ', build_time3)
print('average search time: ', avg_query_time3)
print('average precision: ', avg_precision3)
100%|██████████| 238703/238703 [00:23<00:00, 10106.25it/s]
build time:  40.743391036987305
average search time:  7.756965745461889e-05
average precision:  0.7929644788712333
In [16]:
algo_params = {
    'index_params': {'M': 5, 'post': 0, 'efConstruction': 400}
}

build_time4, avg_query_time4, avg_precision4 = run_algo(
    X_train, X_test, topn, ground_truth, algo_type, algo_params)

print('build time: ', build_time4)
print('average search time: ', avg_query_time4)
print('average precision: ', avg_precision4)
100%|██████████| 238703/238703 [00:24<00:00, 9667.32it/s]
build time:  135.51102113723755
average search time:  8.058855207119328e-05
average precision:  0.8155624353275828

Based on the results, we can see that larger values of the parameters M and efConstruction do give better precision scores. Another observation is that, for M = 16, the precision with efConstruction = 100 is on par with efConstruction = 400 while taking only about one third of the time to build the index.
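To make the comparison easier to eyeball, we can gather the four runs above into a single printout (using the variables already computed in the cells above):

results = [
    ('M=16, efConstruction=100', build_time1, avg_query_time1, avg_precision1),
    ('M=16, efConstruction=400', build_time2, avg_query_time2, avg_precision2),
    ('M=5,  efConstruction=100', build_time3, avg_query_time3, avg_precision3),
    ('M=5,  efConstruction=400', build_time4, avg_query_time4, avg_precision4)
]
for name, build, query, precision in results:
    print('{} | build: {:.1f}s | query: {:.2e}s | precision: {:.3f}'.format(
        name, build, query, precision))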

Reference