faiss是为稠密向量提供高效相似度搜索和聚类的框架。由Facebook AI Research研发。
详见github https://github.com/facebookresearch/faiss
faiss常用的两个相似度搜索是L2欧氏距离搜索和余弦距离搜索(注意不是余弦相似度)
简单的使用流程:
import faiss
index = faiss.IndexFlatL2(d) # 建立L2索引,d是向量维度
index = faiss.IndexFlatIP(d) # 建立Inner product索引
index.add(train) # 添加矩阵
D,I = index.search(test, k) # (D.shape = test.shape[0] * k, I同理)
上述代码实现了对于test向量(也可以是矩阵)索引train中L2距离最近的k个向量,返回其具体distance和索引index
IndexFlatIP()函数实现的是余弦距离的计算也就是 x y t xy^t xyt,显然,当向量范数不为一的情况下不能等同于余弦相似度 x y t ∣ ∣ x ∣ ∣ ∣ ∣ y ∣ ∣ \frac{xy^t}{||x||||y||} ∣∣x∣∣∣∣y∣∣xyt
在许多论文特别是需要计算索引的时候,相似度往往选择余弦相似度,因此在这里记录一下如何实现:
train = np.array([[1.0,1.0],[2.5,0],[0,2.5],[1.5,0.5]]).astype('float32') # 注意 必须为float32类型
test = np.array([[0.5,0.5]]).astype('float32')
print('L2 norm of train', np.linalg.norm(train[0]))
print('L2 norm of test', np.linalg.norm(test))
faiss.normalize_L2(train)
faiss.normalize_L2(test)
print('L2 norm of train', np.linalg.norm(train[0]))
print('L2 norm of test', np.linalg.norm(test))
L2 norm of train 1.4142135
L2 norm of test 0.70710677
L2 norm of train 0.99999994
L2 norm of test 0.99999994
对于被索引矩阵和查询向量,都先经过L2归一化,(normlize_L2函数)
定义索引函数
def KNN_cos(train_set, test_set, n_neighbours):
index = faiss.IndexFlatIP(train_set.shape[1])
index.add(train_set)
D, I = index.search(test_set, n_neighbours)
return D,I
测试
Distance, Index = KNN_cos(train, test,3)
Distance: (array([[0.99999994, 0.8944272 , 0.70710677]], dtype=float32),
Index: array([[0, 3, 2]]))
在github上看到有人给出这样的解决方法
num_vectors = 1000000
vector_dim = 1024
vectors = np.random.rand(num_vectors, vector_dim)
#sample index code
quantizer = faiss.IndexFlatIP(1024)
index = faiss.IndexIVFFlat(quantizer, vector_dim, int(np.sqrt(num_vectors)), faiss.METRIC_INNER_PRODUCT) # 利用IVFFLat提升效率
train_vectors = vectors[:int(num_vectors/2)].copy()
faiss.normalize_L2(train_vectors)
index.train(train_vectors)
faiss.normalize_L2(vectors)
index.add(vectors)
#index creation done
#let's search
query_vector = np.random.rand(10, 1024)
faiss.normalize_L2(query_vector)
D, I = index.search(query_vector, 100)
print(D)
其实这里做了个提速:利用IVFlat先进行聚类再索引,提升效率,详见可以看官方源码
关于faiss库进行索引查询还有很多操作,特别是对于海量数据,合理的利用faiss可以极大提升效率。