Faiss 源码解析

image.png

Faiss 源码解析

faissfacebook 开源的一个专门用于做高维向量的相似性搜索的库,有 c++python 的接口;目前项目地址在 https://github.com/facebookresearch/faiss。本文主要结合 faiss 的官方示例,介绍如何使用 faiss 以及 暴力/IVF/IVFPQ 检索算法在 faiss 的具体实现。

检索算法介绍

检索算法的介绍可以参考 科普,本文主要关注3种检索算法:

  1. 暴力搜索:顾名思义,querybase 一一比对,选择最近的
  2. IVF:首先在具有代表性的数据上训练聚类中心,然后将 base 加入到最近的聚类中心的桶里,在 search 的时候,query 先和聚类中心比对,再在一定数目的桶里做暴力搜索
  3. IVFPQ:在 IVF 的基础上,将 basePQ 量化,加速比对

faiss 的编译与安装

可以参考官方给出的编译方法,这里我没有安装 cuda,所以采用的命令是

./configure --without-cuda && make

在编译完 faiss 之后,我们对官方提供的示例也进行编译,路径在 ./tutorial/cpp 下,cd到目录下直接 make 就可以了

如何使用 faiss

官方总共提供了五个示例,其中有两个是 gpu 版本的,三个是 cpu 版本的,我们这里主要关注 cpu 的,分别是 1-Flat.cpp2-IVFFLAT.cpp3-IVFPQ.cpp,分别对应着暴力算法检索,IVF 算法检索,IVFPQ 算法检索。不同的算法在用户侧代码基本一致,我们选取 IVFPQ 做简单介绍。

#include 
#include 

#include 
#include 


int main() {
    int d = 64;                            // 特征维度
    int nb = 100000;                       // base 样本数量
    int nq = 10000;                        // query 样本数量

    float *xb = new float[d * nb];
    float *xq = new float[d * nq];

    for(int i = 0; i < nb; i++) {
        for(int j = 0; j < d; j++)
            xb[d * i + j] = drand48();
        xb[d * i] += i / 1000.;
    } // 随机初始化 base 数据

    for(int i = 0; i < nq; i++) {
        for(int j = 0; j < d; j++)
            xq[d * i + j] = drand48();
        xq[d * i] += i / 1000.;
    }    // 随机初始化 query 数据


    int nlist = 100;  // 聚类中心个数
    int k = 4;
    int m = 8;                             // bytes per vector
    faiss::IndexFlatL2 quantizer(d);       // 初始化用 L2 暴力 search 的 index
    faiss::IndexIVFPQ index(&quantizer, d, nlist, m, 8); // 初始化 ivfpq 的 index,用 L2 暴力 search 的 index 初始化
    index.train(nb, xb); // 训练 index
    index.add(nb, xb); // 将 base 数据加入到 index 中,用于之后的搜索

    {       // search xq
        long *I = new long[k * nq];
        float *D = new float[k * nq];

        index.nprobe = 10; // 搜索 10 个中心点
        index.search(nq, xq, k, D, I);

        printf("I=\n");
        for(int i = nq - 5; i < nq; i++) {
            for(int j = 0; j < k; j++)
                printf("%5ld ", I[i * k + j]);
            printf("\n");
        }

        delete [] I;
        delete [] D;
    }



    delete [] xb;
    delete [] xq;

    return 0;
}

这段代码主要包括了四个部分,分别是

  1. 初始化 base/query 数据和 index
  2. 训练 index
  3. 加入baseindex
  4. querysearch

其中,使用 faiss 主要包含了三步。初始化数据准备不用多说,faiss 中要求的数据格式都是 n * d 的矩阵格式,然后被展平到一维 float 数组中。剩下的两步,都是对 index 进行操作。

源码解析

检索流程

参考官方给的例子,检索分为三步:trainaddsearch,不同的检索算法,体现在使用不同的 index 进行这三步上

  1. train:选取有代表性的数据,训练 index
  2. add:将 base 数据加入到 index
  3. search:对于给定的 query,返回其对应的在底库中的 topk

重要类

Index

index 的基类,后续各种各样的检索算法,都会继承这个基类或者这个类的派生类,然后实现具体的方法,在这个类中,有如下的数据成员:

  • d:维度,每个向量的维度
  • ntotal:索引的向量的数目,可以理解成检索时的 base 数目
  • metric_type:检索时使用的 metric 类型,比如 L2,内积等

IndexFlat

用于做暴力搜索的 index 类,直接继承 index。暴力搜索思路很简单,无需 trainadd 的所有 base 都被存储起来,然后在 search 的时候把 query 和所有 base 进行比对,选取最近的。我们看下具体实现。

  • add

add 就是把所有的 base 都存储起来

void IndexFlat::add (idx_t n, const float *x) {
    xb.insert(xb.end(), x, x + n * d);
    ntotal += n;
}

  • Search

Search 的时候,根据 metric type 的不同,返回 querytopk。具体计算时采用了 openmpsse/avx 优化

void IndexFlat::search (idx_t n, const float *x, idx_t k,
                               float *distances, idx_t *labels) const
{
    // we see the distances and labels as heaps

    if (metric_type == METRIC_INNER_PRODUCT) {
        float_minheap_array_t res = {
            size_t(n), size_t(k), labels, distances};
        knn_inner_product (x, xb.data(), d, n, ntotal, &res); //函数内部有并行优化
    } else if (metric_type == METRIC_L2) {
        float_maxheap_array_t res = {
            size_t(n), size_t(k), labels, distances};
        knn_L2sqr (x, xb.data(), d, n, ntotal, &res);
    } else {
        float_maxheap_array_t res = {
            size_t(n), size_t(k), labels, distances};
        knn_extra_metrics (x, xb.data(), d, n, ntotal,
                           metric_type, metric_arg,
                           &res);
    }
}

Clustering

实现 K-means 聚类的类,提供train ,需要训练数据和 index(用于 search 最近的向量),结果得到训练数据的类中心向量,如果是量化的向量,那么还需要提供量化使用的 index codec,我们去除量化的部分,只看 float 数据

核心代码如下,包括如下部分:

  • search过程,将聚类中心作为底库加入到 index 中,并对训练数据做 search,得到 assign
  • 计算新的聚类中心,计算新的聚类中心的代码在 compute_centroids中,具体就是对于相同的类别的向量,将向量的均值作为新的中心,在实现上,利用 openmp 进行了并行优化

重复以上两步,就可以得到最优的聚类中心

void Clustering::train_encoded (idx_t nx, const uint8_t *x_in,
                                const Index * codec, Index & index,
                                const float *weights) {
  // 前处理省略  
  for (int redo = 0; redo < nredo; redo++) {

        if (verbose && nredo > 1) {
            printf("Outer iteration %d / %d\n", redo, nredo);
        }

        // initialize (remaining) centroids with random points from the dataset
        centroids.resize (d * k);
        std::vector perm (nx);

        rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);

        for (int i = n_input_centroids; i < k ; i++) {
          memcpy (¢roids[i * d], x + perm[i] * line_size, line_size);
        }

        post_process_centroids ();

        // prepare the index

        if (index.ntotal != 0) {
            index.reset();
        }

        index.add (k, centroids.data());

        // k-means iterations

        float err = 0;
        for (int i = 0; i < niter; i++) {
            double t0s = getmillisecs();
                        index.search (nx, reinterpret_cast(x), 1,
                          dis.get(), assign.get());

            InterruptCallback::check();
            t_search_tot += getmillisecs() - t0s;

            // accumulate error
            err = 0;
            for (int j = 0; j < nx; j++) {
                err += dis[j];
            }

            // update the centroids
            std::vector hassign (k);

            size_t k_frozen = frozen_centroids ? n_input_centroids : 0;
            compute_centroids (
                  d, k, nx, k_frozen,
                  x, codec, assign.get(), weights,
                  hassign.data(), centroids.data()
            );

            index.reset ();
            if (update_index) {
                index.train (k, centroids.data());
            }

            index.add (k, centroids.data());
            InterruptCallback::check ();
        }

    }
    //保存最优聚类中心
    if (nredo > 1) {
        centroids = best_centroids;
        iteration_stats = best_obj;
        index.reset();
        index.add(k, best_centroids.data());
    }

}

void compute_centroids (size_t d, size_t k, size_t n,
                       size_t k_frozen,
                       const uint8_t * x, const Index *codec,
                       const int64_t * assign,
                       const float * weights,
                       float * hassign,
                       float * centroids)
{
    k -= k_frozen;
    centroids += k_frozen * d;

    memset (centroids, 0, sizeof(*centroids) * d * k);

    size_t line_size = codec ? codec->sa_code_size() : d * sizeof (float);

#pragma omp parallel
    {
        int nt = omp_get_num_threads();
        int rank = omp_get_thread_num();

        // this thread is taking care of centroids c0:c1
        size_t c0 = (k * rank) / nt;
        size_t c1 = (k * (rank + 1)) / nt;
        std::vector decode_buffer (d);

        for (size_t i = 0; i < n; i++) {
            int64_t ci = assign[i];
            assert (ci >= 0 && ci < k + k_frozen);
            ci -= k_frozen;
            if (ci >= c0 && ci < c1)  {
                float * c = centroids + ci * d;
                const float * xi;
                if (!codec) {
                    xi = reinterpret_cast(x + i * line_size);
                } else {
                    float *xif = decode_buffer.data();
                    codec->sa_decode (1, x + i * line_size, xif);
                    xi = xif;
                }
                if (weights) {
                    float w = weights[i];
                    hassign[ci] += w;
                    for (size_t j = 0; j < d; j++) {
                        c[j] += xi[j] * w;
                    }
                } else {
                    hassign[ci] += 1.0;
                    for (size_t j = 0; j < d; j++) {
                        c[j] += xi[j];
                    }
                }
            }
        }

    }

#pragma omp parallel for
    for (size_t ci = 0; ci < k; ci++) {
        if (hassign[ci] == 0) {
            continue;
        }
        float norm = 1 / hassign[ci];
        float * c = centroids + ci * d;
        for (size_t j = 0; j < d; j++) {
            c[j] *= norm;
        }
    }

}

IndexIVF

用于做 IVF 搜索的 index 类。

  • train

ivf 算法会把给定的数据进行聚类,得到固定数目的聚类中心。具体的,就是 train_q1​ 的过程,train_residual 在 ivf 中是一个空函数

void IndexIVF::train (idx_t n, const float *x)
{
    train_q1 (n, x, verbose, metric_type);
    train_residual (n, x);
    is_trained = true;
}

void IndexIVF::train_residual(idx_t /*n*/, const float* /*x*/) {
  if (verbose)
    printf("IndexIVF: no residual training\n");
  // does nothing by default
}

train_q1用的是 Level1Quantizer 的具体实现,如下,对训练数据进行聚类,得到聚类中心并保存下来

void Level1Quantizer::train_q1 (size_t n, const float *x, bool verbose, MetricType metric_type)
{
    // 省略无关代码
    Clustering clus (d, nlist, cp);
    quantizer->reset();
    if (clustering_index) {
      clus.train (n, x, *clustering_index);
      quantizer->add (nlist, clus.centroids.data());
    } else {
      clus.train (n, x, *quantizer);
    }
    quantizer->is_trained = true;
}

  • add

  • 分片。根据输入的大小,按照固定的大小依次进行 add

  • 建立 invlists。根据 train得到的聚类中心(保存在 quantizer 中),每一个类中心对应 invlists 中的一个桶。

  • invlists 的桶里加入 base。利用了 openmp 进行了并行加速

void IndexIVF::add (idx_t n, const float * x)
{
    add_with_ids (n, x, nullptr);
}

void IndexIVF::add_with_ids (idx_t n, const float * x, const idx_t *xids)
{
    // do some blocking to avoid excessive allocs
    idx_t bs = 65536;
    if (n > bs) {
        for (idx_t i0 = 0; i0 < n; i0 += bs) {
            idx_t i1 = std::min (n, i0 + bs);
            if (verbose) {
                printf("   IndexIVF::add_with_ids %ld:%ld\n", i0, i1);
            }
            add_with_ids (i1 - i0, x + i0 * d,
                          xids ? xids + i0 : nullptr);
        }
        return;
    }

    std::unique_ptr idx(new idx_t[n]);
    quantizer->assign (n, x, idx.get());
    size_t nadd = 0, nminus1 = 0;

#pragma omp parallel reduction(+: nadd)
    {
        int nt = omp_get_num_threads();
        int rank = omp_get_thread_num();

        // each thread takes care of a subset of lists
        for (size_t i = 0; i < n; i++) {
            idx_t list_no = idx [i];
            if (list_no >= 0 && list_no % nt == rank) {
                idx_t id = xids ? xids[i] : ntotal + i;
                size_t ofs = invlists->add_entry (
                     list_no, id,
                     flat_codes.get() + i * code_size
                );

                dm_adder.add (i, list_no, ofs);

                nadd++;
            } else if (rank == 0 && list_no == -1) {
                dm_adder.add (i, -1, 0);
            }
        }
    }

    ntotal += n;
}

  • search

  • Search corse_dis。搜索离 query 最近的聚类中心

  • Search invlists。在最近的 nprobe 个聚类中心对应的 invlists 中进行暴力 heap 搜索,得到 topk

void IndexIVF::search (idx_t n, const float *x, idx_t k,
                         float *distances, idx_t *labels) const
{
    std::unique_ptr idx(new idx_t[n * nprobe]);
    std::unique_ptr coarse_dis(new float[n * nprobe]);

    double t0 = getmillisecs();
    quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get());
    indexIVF_stats.quantization_time += getmillisecs() - t0;

    t0 = getmillisecs();
    invlists->prefetch_lists (idx.get(), n * nprobe);

    search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
                        distances, labels, false);
    indexIVF_stats.search_time += getmillisecs() - t0;
}

ProductQuantizer

用来做 PQ 量化算法的类,关于 PQ 量化算法,可以参考 pq算法。简单来说,我们需要得到用来量化的码本,然后我们可以对输入的向量进行解码和编码。得到码本的过程在 ProductQuantizer::train 中,包含

  • 将输入向量按照维度切分成 PQ 段,每段的维度是 dsub
  • 得到每段的聚类中心,这就是码本

编码和解码的过程就是将输入向量转化为码本里的 idx,可以看出,量化是存在一定的误差,其中,PQ 越大,误差越小

void ProductQuantizer::train (int n, const float * x)
{
    if (train_type != Train_shared) {
        train_type_t final_train_type;
        final_train_type = train_type;
        if (train_type == Train_hypercube ||
            train_type == Train_hypercube_pca) {
            if (dsub < nbits) {
                final_train_type = Train_default;
                printf ("cannot train hypercube: nbits=%ld > log2(d=%ld)\n",
                        nbits, dsub);
            }
        }

        float * xslice = new float[n * dsub];
        ScopeDeleter del (xslice);
        for (int m = 0; m < M; m++) {
            for (int j = 0; j < n; j++)
                memcpy (xslice + j * dsub,
                        x + j * d + m * dsub,
                        dsub * sizeof(float));

            Clustering clus (dsub, ksub, cp);

            // we have some initialization for the centroids
            if (final_train_type != Train_default) {
                clus.centroids.resize (dsub * ksub);
            }

            switch (final_train_type) {
            case Train_hypercube:
                init_hypercube (dsub, nbits, n, xslice,
                                clus.centroids.data ());
                break;
            case  Train_hypercube_pca:
                init_hypercube_pca (dsub, nbits, n, xslice,
                                    clus.centroids.data ());
                break;
            case  Train_hot_start:
                memcpy (clus.centroids.data(),
                        get_centroids (m, 0),
                        dsub * ksub * sizeof (float));
                break;
            default: ;
            }

            if(verbose) {
                clus.verbose = true;
                printf ("Training PQ slice %d/%zd\n", m, M);
            }
            IndexFlatL2 index (dsub);
            clus.train (n, xslice, assign_index ? *assign_index : index);
            set_params (clus.centroids.data(), m);
        }

    } else {

        Clustering clus (dsub, ksub, cp);

        if(verbose) {
            clus.verbose = true;
            printf ("Training all PQ slices at once\n");
        }

        IndexFlatL2 index (dsub);

        clus.train (n * M, x, assign_index ? *assign_index : index);
        for (int m = 0; m < M; m++) {
            set_params (clus.centroids.data(), m);
        }

    }
}

IndexIVFPQ

ivfpq 算法在 ivf 的基础上,对 basepq。大家可以自行参考代码

你可能感兴趣的:(Faiss 源码解析)