摘要: 这里主要讲整个实现过程与核心思路。
前面讲的IndexFlatL2的索引方式,主要就是一种暴力搜索的方式,只是在计算的过程中针对不同的平台进行了指令集优化。
这里的IndexIVFFlat索引主要是先用k-means把底库向量聚成nlist个簇(分桶),搜索时只在与query最近的nprobe个簇内做暴力搜索,从而减少距离计算量。
code
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <vector>

#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFFlat.h>
#include <faiss/utils.h> // faiss::getmillisecs; this header is faiss/utils/utils.h in newer Faiss releases
int main() {
int d = 64; // dimension
int nb = 100000; // database size
int nq = 10000; // nb of queries
float *xb = new float[d * nb];
float *xq = new float[d * nq];
for(int i = 0; i < nb; i++) {
for(int j = 0; j < d; j++)
xb[d * i + j] = drand48();
xb[d * i] += i / 1000.;
}
for(int i = 0; i < nq; i++) {
for(int j = 0; j < d; j++)
xq[d * i + j] = drand48();
xq[d * i] += i / 1000.;
}
int nlist = 100;
int k = 4;
faiss::IndexFlatL2 quantizer(d); // the other index
faiss::IndexIVFFlat index(&quantizer, d, nlist, faiss::METRIC_L2);
// here we specify METRIC_L2, by default it performs inner-product search
double t0 = faiss::getmillisecs();
index.verbose = 1;
assert(!index.is_trained);
index.train(nb, xb);
double t1 = faiss::getmillisecs();
printf("train time:%.3f \n", (t1-t0)/1000.0);
assert(index.is_trained);
index.add(nb, xb); // 对底库根据聚类的中心点分桶装
double t2 = faiss::getmillisecs();
printf("add time:%.3f \n", (t2-t1)/1000.0);
{ // search xq
long *I = new long[k * nq];
float *D = new float[k * nq];
index.search(nq, xq, k, D, I);
double t3 = faiss::getmillisecs();
printf("search1 time:%.3f \n", (t3-t2)/1000.0);
printf("I=\n");
for(int i = nq - 5; i < nq; i++) {
for(int j = 0; j < k; j++)
printf("%5ld ", I[i * k + j]);
printf("\n");
}
index.nprobe = 10;
index.search(nq, xq, k, D, I);
double t4 = faiss::getmillisecs();
printf("search2 time:%.3f \n", (t4-t3)/1000.0);
printf("I=\n");
for(int i = nq - 5; i < nq; i++) {
for(int j = 0; j < k; j++)
printf("%5ld ", I[i * k + j]);
printf("\n");
}
delete [] I;
delete [] D;
}
delete [] xb;
delete [] xq;
return 0;
}
Training level-1 quantizer
Training level-1 quantizer on 100000 vectors in 64D
Training IVF residual
IndexIVF: no residual training
train time:0.190
IndexIVFFlat::add_core: added 100000 / 100000 vectors
add time:0.074
search1 time:0.044
I=
10827 10004 10049 10147
10267 10880 10330 10156
9896 10093 10361 10184
8603 9895 9946 9335
10123 11099 10876 9647
search2 time:0.202
I=
10842 10827 9938 10004
9403 10267 10880 10330
9896 10146 10093 10361
8603 10523 10582 9895
11460 10123 11099 10876
nprobe改变之后,排在首位的搜索结果也可能发生变化(对比上面两次搜索的输出)。nprobe是搜索时要查找的聚类中心(分桶)个数,默认为1个;若nprobe=nlist,则等同于精确查找。
对nprobe×k个搜索结果进行重排序,找出距离最小的k个。为什么会有nprobe×k个搜索结果?因为我们不能完全信任level1的搜索结果,level1的最近邻聚类中心对应的key中并不一定包含level2的最近邻,为了保险起见,我们扩大对level1的信任范围,取最近的nprobe个聚类中心,在它们对应的子数组中分别搜索k近邻,最后再对整个结果进行重排。来源
/**
 * Run k-means on the training vectors.
 *
 * On return `centroids` holds d * k floats and `index` has been reset and
 * filled with those k centroids. The objective (sum of the distances returned
 * by index.search) of every iteration is appended to `obj`.
 *
 * @param nx    number of training vectors; must be >= k
 * @param x_in  training vectors, nx * d floats; every value must be finite
 * @param index index used to assign vectors to centroids; it is reset and
 *              refilled with the centroids (and trained on them if it is not
 *              trained yet, or on every iteration when update_index is set)
 *
 * Fails through the FAISS_THROW_IF_NOT_* macros when nx < k, when the input
 * contains NaN/Inf, or when the size of pre-set centroids is not a multiple
 * of d.
 */
void Clustering::train (idx_t nx, const float *x_in, Index & index) {
    FAISS_THROW_IF_NOT_FMT (nx >= k,
        "Number of training points (%ld) should be at least "
        "as large as number of clusters (%ld)", nx, k);

    double t0 = getmillisecs();

    // yes it is the user's responsibility, but it may spare us some
    // hard-to-debug reports.
    for (size_t i = 0; i < nx * d; i++) {
        FAISS_THROW_IF_NOT_MSG (finite (x_in[i]),
            "input contains NaN's or Inf's"); // input sanity check
    }

    const float *x = x_in;
    ScopeDeleter<float> del1;

    // Too many points: train on a random subset of k * max_points_per_centroid
    // vectors. (The blog's demo lands here; its author notes k=100 and
    // max_points_per_centroid=256 — confirm against the class defaults.)
    if (nx > k * max_points_per_centroid) {
        if (verbose)
            printf("Sampling a subset of %ld / %ld for training\n",
                   k * max_points_per_centroid, nx);
        std::vector<int> perm (nx);
        rand_perm (perm.data (), nx, seed);
        nx = k * max_points_per_centroid; // new total number of training samples
        float * x_new = new float [nx * d];
        for (idx_t i = 0; i < nx; i++)
            memcpy (x_new + i * d, x + perm[i] * d, sizeof(x_new[0]) * d); // random down-sampling
        x = x_new;
        del1.set (x); // hands x_new to del1 (presumably freed at scope exit)
    } else if (nx < k * min_points_per_centroid) {
        fprintf (stderr,
                 "WARNING clustering %ld points to %ld centroids: "
                 "please provide at least %ld training points\n",
                 nx, k, idx_t(k) * min_points_per_centroid);
    }

    if (nx == k) {
        if (verbose) {
            printf("Number of training points (%ld) same as number of "
                   "clusters, just copying\n", nx);
        }
        // this is a corner case, just copy training set to clusters
        centroids.resize (d * k);
        memcpy (centroids.data(), x_in, sizeof (*x_in) * d * k);
        index.reset();
        index.add(k, x_in);
        return;
    }

    if (verbose)
        printf("Clustering %d points in %ldD to %ld clusters, "
               "redo %d times, %d iterations\n",
               int(nx), d, k, nredo, niter);

    // Per-vector assignment (nearest centroid id) and distance buffers.
    idx_t * assign = new idx_t[nx];
    ScopeDeleter<idx_t> del (assign);
    float * dis = new float[nx];
    ScopeDeleter<float> del2(dis);

    // for redo
    float best_err = HUGE_VALF;
    std::vector<float> best_obj;
    std::vector<float> best_centroids;

    // support input centroids
    FAISS_THROW_IF_NOT_MSG (
        centroids.size() % d == 0,
        "size of provided input centroids not a multiple of dimension");

    // Number of centroids supplied by the caller before training
    // (0 when `centroids` was empty on entry).
    size_t n_input_centroids = centroids.size() / d;

    if (verbose && n_input_centroids > 0) {
        printf (" Using %zd centroids provided as input (%sfrozen)\n",
                n_input_centroids, frozen_centroids ? "" : "not ");
    }

    double t_search_tot = 0;
    if (verbose) {
        printf(" Preprocessing in %.2f s\n",
               (getmillisecs() - t0) / 1000.);
    }
    t0 = getmillisecs();

    // Each "redo" is one complete k-means run from a fresh random
    // initialisation (the blog's author notes nredo=1 in the demo).
    for (int redo = 0; redo < nredo; redo++) {

        if (verbose && nredo > 1) {
            printf("Outer iteration %d / %d\n", redo, nredo);
        }

        // initialize remaining centroids with random points from the dataset
        centroids.resize (d * k); // storage for the k centroids
        std::vector<int> perm (nx); // permutation over the training samples
        rand_perm (perm.data(), nx, seed + 1 + redo * 15486557L);
        for (int i = n_input_centroids; i < k ; i++) // random initialisation of the centroids
            memcpy (&centroids[i * d], x + perm[i] * d,
                    d * sizeof (float));

        post_process_centroids ();

        if (index.ntotal != 0) {
            index.reset();
        }

        if (!index.is_trained) {
            index.train (k, centroids.data()); // only when the index is not trained yet
        }

        index.add (k, centroids.data()); // the index now holds the current centroids
        float err = 0;
        for (int i = 0; i < niter; i++) { // k-means iterations
            double t0s = getmillisecs();
            // Assignment step: nearest centroid (id in assign, distance in
            // dis) for every training vector.
            index.search (nx, x, 1, dis, assign);
            InterruptCallback::check();
            t_search_tot += getmillisecs() - t0s; // accumulated search time

            err = 0;
            for (int j = 0; j < nx; j++) // objective: sum of distances
                err += dis[j];
            obj.push_back (err);

            int nsplit = km_update_centroids ( // update step: recompute the centroids
                x, centroids.data(),
                assign, d, k, nx, frozen_centroids ? n_input_centroids : 0);

            if (verbose) {
                printf (" Iteration %d (%.2f s, search %.2f s): "
                        "objective=%g imbalance=%.3f nsplit=%d \r",
                        i, (getmillisecs() - t0) / 1000.0,
                        t_search_tot / 1000,
                        err, imbalance_factor (nx, k, assign),
                        nsplit);
                fflush (stdout);
            }

            post_process_centroids ();

            index.reset ();
            if (update_index) // retrain the index on the new centroids only when requested
                index.train (k, centroids.data());

            assert (index.ntotal == 0);
            index.add (k, centroids.data()); // put the updated centroids back into the index
            InterruptCallback::check ();
        }
        if (verbose) printf("\n");
        if (nredo > 1) {
            // Keep the run with the lowest final objective.
            if (err < best_err) {
                if (verbose)
                    printf ("Objective improved: keep new clusters\n");
                best_centroids = centroids;
                best_obj = obj;
                best_err = err;
            }
            index.reset ();
        }
    }
    if (nredo > 1) {
        // Restore the best run and load its centroids into the index.
        centroids = best_centroids;
        obj = best_obj;
        index.reset();
        index.add(k, best_centroids.data());
    }
}
// Excerpt from the body of the routine that prints "IndexIVFFlat::add_core"
// below; its signature is not part of this excerpt. It appends n vectors
// (x, optional ids xids) to the inverted lists.
FAISS_THROW_IF_NOT (is_trained);
assert (invlists);
FAISS_THROW_IF_NOT_MSG (!(maintain_direct_map && xids),
"cannot have direct map and add with ids");
// idx[i] is the inverted-list number of vector i: either supplied by the
// caller (precomputed_idx) or computed here with the coarse quantizer.
const int64_t * idx;
ScopeDeleter<int64_t> del;
if (precomputed_idx) {
idx = precomputed_idx;
} else {
int64_t * idx0 = new int64_t [n];
del.set (idx0);
quantizer->assign (n, x, idx0); // assign each input vector to a cluster centroid
idx = idx0;
}
int64_t n_add = 0;
for (size_t i = 0; i < n; i++) {
// Without explicit ids, vectors are numbered sequentially after ntotal.
int64_t id = xids ? xids[i] : ntotal + i;
int64_t list_no = idx [i]; // index of the matching cluster centroid
// A negative list number means the vector is not stored.
if (list_no < 0)
continue;
const float *xi = x + i * d;
size_t offset = invlists->add_entry (
list_no, id, (const uint8_t*) xi); // append the vector to the list of its centroid
// Direct-map entry: list number shifted left by 32 bits, OR-ed with the offset.
if (maintain_direct_map)
direct_map.push_back (list_no << 32 | offset);
n_add++;
}
if (verbose) {
printf("IndexIVFFlat::add_core: added %ld / %ld vectors\n",
n_add, n);
}
// Note: ntotal grows by n, not n_add, so skipped vectors are counted too.
ntotal += n;
// Excerpt from the body of the IVF search routine (signature not part of this
// excerpt). Step 1: find the nprobe closest cluster centroids for each of the
// n queries; idx receives their ids and coarse_dis their distances.
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]); // per the original author's note, nprobe=1 in the first search of the demo
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);
double t0 = getmillisecs();
quantizer->search (n, x, nprobe, coarse_dis.get(), idx.get()); // distances / assignment to the cluster centroids
indexIVF_stats.quantization_time += getmillisecs() - t0;
t0 = getmillisecs();
invlists->prefetch_lists (idx.get(), n * nprobe); // original author's note: does nothing here — TODO confirm for the inverted-list type in use
// Step 2: exhaustive search inside the selected buckets.
search_preassigned (n, x, k, idx.get(), coarse_dis.get(),
distances, labels, false);
indexIVF_stats.search_time += getmillisecs() - t0;
// Exhaustive search inside the pre-selected buckets (excerpt; the enclosing
// function's signature is not part of this excerpt).
// Per-call search parameters override the index-level ones when given.
long nprobe = params ? params->nprobe : this->nprobe;
long max_codes = params ? params->max_codes : this->max_codes;
// Counters accumulated across threads by the omp reduction clause that follows.
size_t nlistv = 0, ndis = 0, nheap = 0;
// Heap types selected by metric in the result-heap helpers further down.
using HeapForIP = CMin<float, idx_t>;
using HeapForL2 = CMax<float, idx_t>;
bool interrupt = false;
// don't start parallel section if single query
// parallel_mode 0: parallelise only when n > 1; mode 1: only when nprobe > 1;
// any other mode: when nprobe * n > 1.
bool do_parallel =
parallel_mode == 0 ? n > 1 :
parallel_mode == 1 ? nprobe > 1 :
nprobe * n > 1;
#pragma omp parallel if(do_parallel) reduction(+: nlistv, ndis, nheap)
{
// Scanner object used below to set the query, select a list and scan its
// codes. It is declared inside the omp parallel region opened above.
InvertedListScanner *scanner = get_InvertedListScanner(store_pairs);
// Hands the scanner to a scope guard (presumably deleted at scope exit; ScopeDeleter1 is defined elsewhere).
ScopeDeleter1<InvertedListScanner> del(scanner);
/*****************************************************
* Depending on parallel_mode, there are two possible ways
* to organize the search. Here we define local functions
* that are in common between the two
******************************************************/
// Initialise the k-slot result heap of one query (simi = distances or
// similarities, idxi = labels). The heap type follows the metric:
// inner product uses HeapForIP, every other metric uses HeapForL2.
auto init_result = [&](float *simi, idx_t *idxi) {
    const bool is_inner_product = (metric_type == METRIC_INNER_PRODUCT);
    if (!is_inner_product) {
        heap_heapify<HeapForL2> (k, simi, idxi);
        return;
    }
    heap_heapify<HeapForIP> (k, simi, idxi);
};
// Reorder the k-slot result heap of one query (simi, idxi) once scanning is
// finished. The heap type follows the metric: inner product uses HeapForIP,
// every other metric uses HeapForL2.
auto reorder_result = [&] (float *simi, idx_t *idxi) {
    const bool is_inner_product = (metric_type == METRIC_INNER_PRODUCT);
    if (!is_inner_product) {
        heap_reorder<HeapForL2> (k, simi, idxi);
        return;
    }
    heap_reorder<HeapForIP> (k, simi, idxi);
};
// single list scan using the current scanner (with query
// set porperly) and storing results in simi and idxi
auto scan_one_list = [&] (idx_t key, float coarse_dis_i, //
float *simi, idx_t *idxi) {
if (key < 0) {
// not enough centroids for multiprobe
return (size_t)0;
}
FAISS_THROW_IF_NOT_FMT (key < (idx_t) nlist,
"Invalid key=%ld nlist=%ld\n", // key聚类中心点的索引
key, nlist);
size_t list_size = invlists->list_size(key); // 聚类中心点的样本数
// don't waste time on empty lists
if (list_size == 0) {
return (size_t)0;
}
scanner->set_list (key, coarse_dis_i);
nlistv++;
InvertedLists::ScopedCodes scodes (invlists, key); // 聚类中心样本的数值
std::unique_ptr<InvertedLists::ScopedIds> sids;
const Index::idx_t * ids = nullptr;
if (!store_pairs) {
sids.reset (new InvertedLists::ScopedIds (invlists, key)); // 聚类中心样本的索引
ids = sids->get();
}
nheap += scanner->scan_codes (list_size, scodes.get(),
ids, simi, idxi, k); // simi,idxi用于存放和query匹配的样本的距离和索引
return list_size;
};
/****************************************************
* Actual loops, depending on parallel_mode
****************************************************/
if (parallel_mode == 0) {
// parallel_mode 0: the queries are distributed over the threads; each
// iteration handles one query against all of its probed lists.
#pragma omp for
for (size_t i = 0; i < n; i++) {
// Once an interrupt has been seen, the remaining iterations do nothing.
if (interrupt) {
continue;
}
// loop over queries
scanner->set_query (x + i * d); // hand query i to the scanner
// Result slots of query i: k distances and k labels.
float * simi = distances + i * k;
idx_t * idxi = labels + i * k;
init_result (simi, idxi);
long nscan = 0;
// loop over probes
for (size_t ik = 0; ik < nprobe; ik++) {
nscan += scan_one_list ( // scan the ik-th probed list for this query
keys [i * nprobe + ik],
coarse_dis[i * nprobe + ik],
simi, idxi
);
// A non-zero max_codes caps the number of entries scanned per query.
if (max_codes && nscan >= max_codes) {
break;
}
}
ndis += nscan;
reorder_result (simi, idxi); // put the results in simi / idxi into their final order
if (InterruptCallback::is_interrupted ()) {
interrupt = true;
}
} // parallel for