In [1]:
# code for loading the format for the notebook
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))

from formats import load_style
load_style(plot_style=False)
Out[1]:
In [2]:
os.chdir(path)

# 1. magic for inline plot
# 2. magic to print version
# 3. magic so that the notebook will reload external python modules
# 4. magic to enable retina (high resolution) plots
# https://gist.github.com/minrk/3301035
%matplotlib inline
%load_ext watermark
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format='retina'

import time
import fasttext
import numpy as np
import pandas as pd

# prevent scientific notations
pd.set_option('display.float_format', lambda x: '%.3f' % x)

%watermark -a 'Ethen' -d -t -v -p numpy,pandas,fasttext,scipy
Ethen 2020-06-08 12:01:50 

CPython 3.6.4
IPython 7.15.0

numpy 1.16.5
pandas 0.25.0
fasttext n
scipy 1.4.1

Approximate Nearest Neighborhood Search with Navigable Small World

Performing nearest neighbor search on embeddings has become a crucial process in many applications, such as similar image/text search. The ann benchmark compares various approximate nearest neighbor search algorithms/libraries, and in this document we'll take a look at one of them: the Navigable Small World graph.

Data Preparation and Model

For the embeddings, we'll train a fasttext multi-label text classification model ourselves and use its output embeddings for this example. The fasttext library has already been introduced in another post, hence we won't go over it in detail. Readers can also swap out the data preparation and model section for embeddings of their choice.

In [3]:
# download the data and un-tar it under the 'data' folder

# -P or --directory-prefix specifies which directory to download the data to
!wget https://dl.fbaipublicfiles.com/fasttext/data/cooking.stackexchange.tar.gz -P data
# -C specifies the target directory to extract an archive to
!tar xvzf data/cooking.stackexchange.tar.gz -C data
--2020-06-08 12:01:51--  https://dl.fbaipublicfiles.com/fasttext/data/cooking.stackexchange.tar.gz
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 172.67.9.4, 104.22.74.142, 104.22.75.142
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|172.67.9.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 457609 (447K) [application/x-tar]
Saving to: ‘data/cooking.stackexchange.tar.gz.1’

cooking.stackexchan 100%[===================>] 446.88K  --.-KB/s    in 0.08s   

2020-06-08 12:01:51 (5.53 MB/s) - ‘data/cooking.stackexchange.tar.gz.1’ saved [457609/457609]

x cooking.stackexchange.id
x cooking.stackexchange.txt
x readme.txt
In [4]:
!head -n 3 data/cooking.stackexchange.txt
__label__sauce __label__cheese How much does potato starch affect a cheese sauce recipe?
__label__food-safety __label__acidity Dangerous pathogens capable of growing in acidic environments
__label__cast-iron __label__stove How do I cover up the white spots on my cast iron stove?
In [5]:
# train/test split
import os
from fasttext_module.split import train_test_split_file
from fasttext_module.utils import prepend_file_name

data_dir = 'data'
test_size = 0.2
input_path = os.path.join(data_dir, 'cooking.stackexchange.txt')
input_path_train = prepend_file_name(input_path, 'train')
input_path_test = prepend_file_name(input_path, 'test')
random_state = 1234
encoding = 'utf-8'

train_test_split_file(input_path, input_path_train, input_path_test,
                      test_size, random_state, encoding)
print('train path: ', input_path_train)
print('test path: ', input_path_test)
train path:  data/train_cooking.stackexchange.txt
test path:  data/test_cooking.stackexchange.txt
In [6]:
# train the fasttext model
fasttext_params = {
    'input': input_path_train,
    'lr': 0.1,
    'lrUpdateRate': 1000,
    'thread': 8,
    'epoch': 15,
    'wordNgrams': 1,
    'dim': 80,
    'loss': 'ova'
}
model = fasttext.train_supervised(**fasttext_params)

print('vocab size: ', len(model.words))
print('label size: ', len(model.labels))
print('example vocab: ', model.words[:5])
print('example label: ', model.labels[:5])
vocab size:  14496
label size:  733
example vocab:  ['</s>', 'to', 'a', 'How', 'the']
example label:  ['__label__baking', '__label__food-safety', '__label__substitutions', '__label__equipment', '__label__bread']
In [7]:
# model.get_input_matrix().shape
print('output matrix shape: ', model.get_output_matrix().shape)
model.get_output_matrix()
output matrix shape:  (733, 80)
Out[7]:
array([[ 4.7899528 , -0.6933957 ,  0.39464658, ..., -2.0341725 ,
        -0.7517707 ,  0.3983426 ],
       [ 2.9305046 , -0.28570035, -2.3910296 , ...,  2.1693978 ,
         0.47595456, -1.6293081 ],
       [ 3.9586446 ,  0.00725545,  0.11945528, ..., -3.3996897 ,
         0.94858617,  1.26207   ],
       ...,
       [ 1.1076808 , -0.49902833,  0.36399806, ..., -0.911734  ,
        -0.25994965,  0.4118186 ],
       [ 1.2184464 , -0.48066193,  0.4087498 , ..., -0.9224677 ,
        -0.21687554,  0.4334651 ],
       [ 1.1682702 , -0.48476282,  0.4047362 , ..., -0.9493868 ,
        -0.29047066,  0.44329277]], dtype=float32)

Given the output matrix, in which each row is the embedding of one label, we would like to find the nearest neighbors of each of these label vectors.

Readers interested in using other embeddings can replace index_factors with their own embedding matrix and query_factors with a random row from that matrix; the rest of the document should still work as-is.

In [8]:
# we'll get one of the labels to find its nearest neighbors 
label_id = 0
print(model.labels[label_id])

index_factors = model.get_output_matrix()
query_factors = model.get_output_matrix()[label_id]
query_factors.shape
__label__baking
Out[8]:
(80,)

We'll start off by formally defining the problem. k-nearest neighbor search is the problem where, given a query object $q$, we need to find the $k$ closest objects from a fixed set of objects $O \subseteq D$, where $D$ is the set of all possible objects at hand.
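Before approximating anything, it helps to have the exact answer to compare against. The sketch below is a brute-force k-nearest neighbor search that scans every object, assuming cosine distance (1 minus cosine similarity) and non-zero vectors. `brute_force_knn` is a hypothetical helper name, not part of any library, and the toy data are made up.

```python
import numpy as np


def brute_force_knn(index_factors: np.ndarray, query: np.ndarray, k: int):
    """Exact k-nearest neighbor search by scanning every object (cosine distance)."""
    # normalize rows so that a dot product equals cosine similarity
    index_norm = index_factors / np.linalg.norm(index_factors, axis=1, keepdims=True)
    query_norm = query / np.linalg.norm(query)
    # cosine distance = 1 - cosine similarity
    dist = 1.0 - index_norm @ query_norm
    # argsort sorts in ascending order, so the k smallest distances come first
    top_k = np.argsort(dist)[:k]
    return [(float(dist[i]), int(i)) for i in top_k]


# toy example: 4 objects in 2 dimensions
objects = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]])
print(brute_force_knn(objects, np.array([1.0, 0.0]), k=2))
```

The cost of each query grows linearly with the number of objects, which is exactly what approximate methods such as navigable small world try to avoid. Calling it on index_factors and query_factors could serve as a ground truth for the approximate searches later on.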

The idea behind navigable small world is to use a graph data structure $G(V, E)$ to represent these objects $O$, where every object $o_i$ is represented by a vertex/node $v_i$. The graph is constructed by adding the elements sequentially: for every new element, we find the set of its closest neighbors already in the graph using a variant of the greedy search algorithm, and then introduce a bidirectional connection between that set of neighbors and the incoming element.

Once the graph is built, searching for the closest objects to $q$ is very similar to adding an object to the graph, i.e. it involves traversing the graph to find the closest vertices/nodes using the same variant of the greedy search algorithm that's used when constructing the graph.

Another thing worth noting is that determining the closest neighbors depends on a distance function. As the algorithm doesn't make any strong assumption about the data, it can be used with any distance function of our choice. Here we'll be using the cosine distance as an illustration.
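To make the pluggable distance concrete, here is a minimal sketch of cosine distance, $1 - \frac{u \cdot v}{\|u\| \|v\|}$, which is the definition scipy.spatial.distance.cosine uses, next to Euclidean distance as an example of an alternative that could be swapped in. Both function names are hypothetical and assume non-zero vectors.

```python
import numpy as np


def cosine_distance(u: np.ndarray, v: np.ndarray) -> float:
    """1 - cosine similarity: 0 for the same direction, 1 for orthogonal, 2 for opposite."""
    return 1.0 - float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))


def euclidean_distance(u: np.ndarray, v: np.ndarray) -> float:
    """An alternative metric that could be plugged into the same search."""
    return float(np.linalg.norm(u - v))


u = np.array([1.0, 0.0])
v = np.array([0.0, 2.0])
print(cosine_distance(u, u), cosine_distance(u, v), cosine_distance(u, -u))  # 0.0 1.0 2.0
print(euclidean_distance(u, v))
```

Cosine distance only looks at direction, so scaling a vector does not change its distance to the others, whereas Euclidean distance does take magnitude into account.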

In [9]:
class Node:
    """
    Node for a navigable small world graph.

    Parameters
    ----------
    idx : int
        For uniquely identifying a node.

    value : 1d np.ndarray
        To access the embedding associated with this node.

    neighborhood : set
        For storing adjacent nodes.

    References
    ----------
    https://book.pythontips.com/en/latest/__slots__magic.html
    https://hynek.me/articles/hashes-and-equality/
    """
    __slots__ = ['idx', 'value', 'neighborhood']

    def __init__(self, idx, value):
        self.idx = idx
        self.value = value
        self.neighborhood = set()

    def __hash__(self):
        return hash(self.idx)

    def __eq__(self, other):
        return (
            self.__class__ == other.__class__ and
            self.idx == other.idx
        )
In [10]:
from scipy.spatial import distance


def build_nsw_graph(index_factors, k):
    n_nodes = index_factors.shape[0]

    graph = []
    for i, value in enumerate(index_factors):
        node = Node(i, value)
        graph.append(node)

    for node in graph:
        query_factor = node.value.reshape(1, -1)

        # note that the following implementation is not the actual procedure that's
        # used to find the k closest neighbors, we're just implementing a quick version,
        # will come back to this later

        # https://codereview.stackexchange.com/questions/55717/efficient-numpy-cosine-distance-calculation
        # the smaller the cosine distance the more similar, thus the most
        # similar item will be the first element after performing argsort
        # since argsort by default sorts in ascending order
        dist = distance.cdist(index_factors, query_factor, metric='cosine').ravel()
        neighbors_indices = np.argsort(dist)[:k].tolist()
        
        # insert bi-directional connection
        node.neighborhood.update(neighbors_indices)
        for i in neighbors_indices:
            graph[i].neighborhood.add(node.idx)

    return graph
In [11]:
k = 10

graph = build_nsw_graph(index_factors, k)
graph[0].neighborhood
Out[11]:
{0, 119, 123, 144, 179, 187, 199, 204, 221, 399}

In the original paper, the authors use the term "friends" for vertices that share an edge, and "friend list" of vertex $v_i$ for the list of vertices that share a common edge with the vertex $v_i$.

We'll now introduce the variant of greedy search that the algorithm uses. The pseudocode looks like the following:

greedy_search(q: object, v_entry_point: object):
    v_curr = v_entry_point
    d_min = dist_func(q, v_curr)
    v_next = None

    for v_friend in v_curr.get_friends():
        d_friend = dist_func(q, v_friend)
        if d_friend < d_min:
            d_min = d_friend
            v_next = v_friend

    if v_next is None:
        return v_curr
    else:
        return greedy_search(q, v_next)
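A minimal runnable version of the pseudocode above, written as a loop instead of recursion. It assumes a toy layout (a `friends` dict of adjacency sets and a `dist_func` callable, both hypothetical) and uses absolute difference on 1-d points rather than cosine distance, so the walk is easy to trace by hand. It also shows that the answer depends on the entry point: starting from vertex 0, the search stops at vertex 2 even though vertex 3 is closer to the query.

```python
def greedy_search(query, entry_point, friends, dist_func):
    """Iterative version of the greedy_search pseudocode.

    friends   : dict mapping each vertex to the set of its friends (toy layout)
    dist_func : callable returning the distance between the query and a vertex
    """
    v_curr = entry_point
    d_min = dist_func(query, v_curr)
    while True:
        v_next = None
        for v_friend in friends[v_curr]:
            d_friend = dist_func(query, v_friend)
            if d_friend < d_min:
                d_min = d_friend
                v_next = v_friend
        if v_next is None:
            # no friend is closer to the query than the current vertex: stop here
            return v_curr
        v_curr = v_next


# toy graph: each vertex is a point on a line
points = {0: 0.0, 1: 5.0, 2: 10.0, 3: 9.0}
friends = {0: {1, 2}, 1: {0, 3}, 2: {0}, 3: {1}}


def dist_func(q, v):
    return abs(q - points[v])


query = 9.1  # the true nearest neighbor is vertex 3
print(greedy_search(query, 0, friends, dist_func))  # 2, a local minimum
print(greedy_search(query, 1, friends, dist_func))  # 3, the true nearest neighbor
```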

Starting from some entry point (chosen at random at the beginning), the greedy search algorithm computes the distance from the input query to each of the current vertex's friends. If the distance between the query and a friend vertex is smaller than the current minimum, the algorithm moves to the closest such friend and repeats the process until it can't find a friend vertex that is closer to the query than the current vertex.

This approach can, of course, get stuck in a local minimum, i.e. the closest vertex/object determined by this greedy search algorithm is not the true closest element to the incoming query. Hence, the idea to extend this is to pick a series of entry points, denoted by m in the pseudocode below, and return the best results from all those greedy searches. With each additional search, the chance of not finding the true nearest neighbors should decrease exponentially.
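The exponential decay can be made concrete with a back-of-the-envelope calculation. Assuming, idealistically, that each greedy search misses the true nearest neighbor independently with probability p, all m searches miss with probability p^m. The value p = 0.5 below is hypothetical, and in practice the searches share the same graph (and, in the implementation below, the same visited set), so they are not truly independent.

```python
def all_searches_miss(p_single_miss: float, m: int) -> float:
    """Probability that m independent greedy searches all miss the true nearest neighbor."""
    return p_single_miss ** m


# hypothetical: a single greedy search misses half of the time
for m in (1, 5, 10, 20):
    print(m, all_searches_miss(0.5, m))
```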

The key idea behind the knn search is that, given a random entry point, it repeatedly expands the unvisited vertex closest to the query, greedily exploring the neighborhood until the $k$ nearest elements found so far can't be improved upon. This process then repeats for the next random entry point.

knn_search(q: object, m: int, k: int):
    queue[object] candidates, temp_result, result
    set[object] visited_set

    for i in range(m):
        put random entry point in candidates
        temp_result = None

        repeat:
            get element c closest from candidates to q
            remove c from candidates

            if c is further than the k-th element from result:
                break repeat

            for every element e from friends of c:
                if e is not in visited_set:
                    add e to visited_set, candidates, temp_result


        add objects from temp_result to result

    return best k elements from result


We'll be using the heapq module as our priority queue.
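As a quick refresher on the parts of heapq used below: it keeps a plain list arranged as a binary min-heap, and since tuples compare element by element, pushing `(distance, index)` pairs makes the closest item pop first. The pairs here are made up for illustration.

```python
import heapq

candidate_queue = []
for dist, idx in [(0.38, 179), (0.0, 0), (0.25, 221), (0.26, 199)]:
    heapq.heappush(candidate_queue, (dist, idx))

closest = heapq.heappop(candidate_queue)  # the smallest distance is popped first
print(closest)  # (0.0, 0)

# nsmallest returns the k smallest items in sorted order without popping them
print(heapq.nsmallest(2, candidate_queue))  # [(0.25, 221), (0.26, 199)]
print(len(candidate_queue))  # 3
```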

In [12]:
import heapq
import random
from typing import List, Tuple


def nsw_knn_search(
    graph: List[Node],
    query: np.ndarray,
    k: int=5,
    m: int=50) -> Tuple[List[Tuple[float, int]], float]:
    """
    Performs knn search using the navigable small world graph.

    Parameters
    ----------
    graph :
        Navigable small world graph from build_nsw_graph.

    query : 1d np.ndarray
        Query embedding that we wish to find the nearest neighbors.

    k : int
        Number of nearest neighbors returned.

    m : int
        The recall set will be chosen from m different entry points.

    Returns
    -------
    The list of nearest neighbors (distance, index) tuple.
    and the average number of hops that was made during the search.
    """
    result_queue = []
    visited_set = set()
    
    hops = 0
    for _ in range(m):
        # random entry point from all possible candidates
        entry_node = random.randint(0, len(graph) - 1)
        entry_dist = distance.cosine(query, graph[entry_node].value)
        candidate_queue = []
        heapq.heappush(candidate_queue, (entry_dist, entry_node))

        temp_result_queue = []
        while candidate_queue:
            candidate_dist, candidate_idx = heapq.heappop(candidate_queue)

            if len(result_queue) >= k:
                # if candidate is further than the k-th element from the result,
                # then we would break the repeat loop
                current_k_dist, current_k_idx = heapq.nsmallest(k, result_queue)[-1]
                if candidate_dist > current_k_dist:
                    break

            for friend_node in graph[candidate_idx].neighborhood:
                if friend_node not in visited_set:
                    visited_set.add(friend_node)

                    friend_dist = distance.cosine(query, graph[friend_node].value)
                    heapq.heappush(candidate_queue, (friend_dist, friend_node))
                    heapq.heappush(temp_result_queue, (friend_dist, friend_node))
                    hops += 1

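        # aggregate this round's candidates into the overall result; the order
        # doesn't matter because heapq.nsmallest picks the best k afterwards
        result_queue.extend(temp_result_queue)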

    return heapq.nsmallest(k, result_queue), hops / m
In [13]:
results = nsw_knn_search(graph, query_factors, k=5)
results
Out[13]:
([(0.0, 0),
  (0.24616026878356934, 221),
  (0.26159465312957764, 199),
  (0.3747814893722534, 187),
  (0.38302379846572876, 179)],
 14.66)

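Now that we've implemented the knn search algorithm, we can go back and modify the graph-building function to construct the navigable small world graph the actual way: each incoming element is connected to the k nearest neighbors that the knn search finds among the elements already in the graph.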

In [14]:
def build_nsw_graph(index_factors: np.ndarray, k: int) -> List[Node]:
    n_nodes = index_factors.shape[0]

    graph = []
    for i, value in enumerate(index_factors):
        node = Node(i, value)
        if i > k:
            neighbors, hops = nsw_knn_search(graph, node.value, k)
            neighbors_indices = [node_idx for _, node_idx in neighbors]
        else:
            neighbors_indices = list(range(i))

        # insert bi-directional connection
        node.neighborhood.update(neighbors_indices)
        for neighbor_idx in neighbors_indices:
            graph[neighbor_idx].neighborhood.add(node.idx)
        
        graph.append(node)

    return graph
In [15]:
k = 10

index_factors = model.get_output_matrix()
graph = build_nsw_graph(index_factors, k)
graph[0].neighborhood
Out[15]:
{1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 15,
 16,
 17,
 21,
 23,
 24,
 27,
 37,
 38,
 39,
 40,
 52,
 54,
 59,
 119,
 123,
 144,
 199,
 221}
In [16]:
results = nsw_knn_search(graph, query_factors, k=5)
results
Out[16]:
([(0.0, 0),
  (0.24616026878356934, 221),
  (0.26159465312957764, 199),
  (0.3747814893722534, 187),
  (0.38302379846572876, 179)],
 14.66)

Hnswlib

We can check our result against a more advanced variant of the algorithm, Hierarchical Navigable Small World (HNSW), provided by hnswlib. The idea is very similar to a skip list, except that the linked lists at each layer are replaced with navigable small world graphs. Although we never formally introduce this hierarchical variant, its major parameters should hopefully look familiar.
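To make the skip-list analogy a bit more concrete: in HNSW, each inserted element is assigned a random top layer $l = \lfloor -\ln(U) \cdot m_L \rfloor$ with $U$ uniform on (0, 1], and the HNSW paper suggests $m_L = 1/\ln(M)$, so an element reaches layer $L$ or above with probability $M^{-L}$. The sketch below assumes that rule (it is not hnswlib's code) and checks that, like a skip list, each layer keeps roughly 1/M of the elements of the layer below. `random_level` is a hypothetical helper.

```python
import math
import random


def random_level(m_l: float, rng: random.Random) -> int:
    """Draw a top layer as floor(-ln(U) * m_l) with U uniform on (0, 1]."""
    u = 1.0 - rng.random()  # rng.random() is in [0, 1), so u is in (0, 1]
    return int(-math.log(u) * m_l)  # int() floors non-negative floats


M = 24
m_l = 1.0 / math.log(M)
rng = random.Random(1234)
levels = [random_level(m_l, rng) for _ in range(100_000)]

# share of elements present on each layer, next to the expected M ** -layer
for layer in range(4):
    share = sum(level >= layer for level in levels) / len(levels)
    print(layer, round(share, 5), round(M ** -layer, 5))
```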

  • ef: The algorithm searches for the ef closest neighbors of the element $q$; in the original navigable small world paper, this variable was simply set to $k$. During the insertion/construction phase (where the parameter is termed ef_construction), these ef closest neighbors form the candidate set for creating bidirectional edges. After the graph is constructed, they form the candidate/recall set from which the actual top k closest elements to the query object are returned.
  • M: Out of the ef_construction candidates, edges are only created between the inserted element and its M closest ones, i.e. M controls the number of bidirectional links per element.

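A minimal sketch of the role M plays during insertion, assuming the simple rule described above: out of the ef_construction `(distance, node_id)` candidates returned by the search, only the M closest become bidirectional edges. The helper names and candidate values are hypothetical; the HNSW paper also describes a heuristic neighbor selection, and hnswlib's actual rule may differ from this simple one.

```python
import heapq


def select_neighbors_simple(candidates, M):
    """Return the ids of the M closest (distance, node_id) candidates."""
    return [node_id for _, node_id in heapq.nsmallest(M, candidates)]


def connect(friends, new_node, candidates, M):
    """Create bidirectional edges between new_node and its M selected neighbors."""
    neighbors = select_neighbors_simple(candidates, M)
    friends.setdefault(new_node, set()).update(neighbors)
    for neighbor in neighbors:
        friends.setdefault(neighbor, set()).add(new_node)
    return neighbors


# hypothetical search result: ef_construction = 5 candidates, keep M = 2 of them
candidates = [(0.9, 1), (0.1, 2), (0.5, 3), (0.3, 4), (0.7, 5)]
friends = {}
print(connect(friends, 100, candidates, M=2))  # [2, 4]
print(friends[100], friends[2], friends[4])
```

A larger M makes the graph denser, which usually helps recall at the cost of memory and construction time.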
The actual process of constructing HNSW and doing knn search is a bit more involved compared to vanilla navigable small world. We won't be getting into all the gory details in this post.

In [17]:
import hnswlib


def build_hnsw(factors, space, ef_construction, M):
    # Declaring index
    max_elements, dim = factors.shape
    hnsw = hnswlib.Index(space, dim) # possible options for space are l2, cosine or ip

    # Initializing index - the maximum number of elements should be known beforehand
    hnsw.init_index(max_elements=max_elements, M=M, ef_construction=ef_construction)

    # Element insertion (can be called several times)
    hnsw.add_items(factors)
    return hnsw
In [18]:
space = 'cosine'
ef_construction = 200
M = 24

start = time.time()
hnsw = build_hnsw(index_factors, space, ef_construction, M)
build_time = time.time() - start
build_time
Out[18]:
0.013833045959472656
In [19]:
k = 5

# Controlling the recall by setting ef, should always be > k
hnsw.set_ef(70)

# retrieve the top-n search neighbors
labels, distances = hnsw.knn_query(query_factors, k=k)
print(labels)
[[  0 221 199 187 179]]
In [20]:
# find the nearest neighbors and "translate" it to the original labels
[model.labels[label] for label in labels[0]]
Out[20]:
['__label__baking',
 '__label__baking-soda',
 '__label__baking-powder',
 '__label__muffins',
 '__label__cheesecake']
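Rather than eyeballing the two lists, their overlap can be quantified with recall@k, the fraction of a reference top-k list that another search also returned. The ids below are copied from the outputs above (our NSW search and hnswlib's labels). Since hnswlib is itself approximate, this measures agreement between the two; a brute-force top-k list would be the proper reference for true recall. `recall_at_k` is a hypothetical helper.

```python
def recall_at_k(approx_ids, exact_ids) -> float:
    """Fraction of the reference top-k ids that the approximate search also returned."""
    exact = set(exact_ids)
    return len(exact.intersection(approx_ids)) / len(exact)


# ids copied from the outputs above: our NSW search and hnswlib's labels[0]
nsw_ids = [0, 221, 199, 187, 179]
hnsw_ids = [0, 221, 199, 187, 179]
print(recall_at_k(nsw_ids, hnsw_ids))  # 1.0
```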

Based on the ann benchmark, Hierarchical Navigable Small World (HNSW) stood out as one of the top-performing approximate nearest neighbor search algorithms at the time of writing. Here, we introduced its vanilla variant, Navigable Small World, and checked that, for the query above, our results match those of a more robust implementation from the open-source library hnswlib.

Reference