HNSW-based vector search on Spark

Reference: https://talks.anghami.com/blazing-fast-approximate-nearest-neighbour-search-on-apache-spark-using-hnsw/
HNSW parameter tuning guide: https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md

Running HNSW vector search on Spark takes three steps:
1 Build the HNSW index and save it to disk
2 Distribute the saved index to every executor
3 Run the vector search
Building the index with HNSW and running the search distributed on Spark: for 12 million vectors, index construction took about 40 minutes and the vector search finished in about 10 minutes. The timings depend on the size of m and ef; I used m=30 and ef=1000, because smaller values kept raising errors about m or ef being too small. With m=30 and ef=1000, building the index for the 12 million vectors took about 20 minutes and the search still took about 10 minutes. A minimal end-to-end sketch of the three steps follows.
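
A minimal end-to-end sketch of the three steps, using the annUtilsHnsw and HnswIndex classes defined in the sections below. The toy vectors, the vector size of 3 and the small topK/ef values are illustrative assumptions only; on real data use your own dataset and tuned values such as the m=30, ef=1000 mentioned above.

import org.apache.spark.sql.SparkSession

object HnswSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._

    // toy (id, features) dataset; in practice this is millions of vectors
    val embeddings = Seq((1, Array(0.1f, 0.2f, 0.3f)), (2, Array(0.4f, 0.5f, 0.6f))).toDS()

    // step 1: build the HNSW index on the driver
    val hnswIndex = new annUtilsHnsw()
      .buildHnswIndex(spark, vectorSize = 3, features = embeddings, m = 30, efConstruction = 1000)

    // steps 2 + 3: knnQuery saves the index, ships it to every executor via SparkFiles,
    // and runs the ANN search on each partition of the query vectors
    val neighbours = hnswIndex.knnQuery[Int](spark, embeddings,
      minScoreThreshold = 0.0, topK = 1, ef = 2,
      queryNumPartitions = 2, indexSavePath = "index", m = 30, efConstruction = 1000)

    neighbours.show()
  }
}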

1 Build the HNSW index

The input is a Spark Dataset consisting of an id column and a features column, where features is an Array[Float] vector. A small conversion sketch is shown below.
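
If your data starts as a DataFrame (for example with an id column and an array&lt;double&gt; features column), it can be converted into that shape as sketched here; the column names and source types are assumptions for illustration.

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

// Convert a DataFrame with columns `id` and `features` into the Dataset[(Int, Array[Float])]
// expected by buildHnswIndex (column names/types assumed for this example).
def toFeatureDataset(spark: SparkSession, df: DataFrame): Dataset[(Int, Array[Float])] = {
  import spark.implicits._
  df.selectExpr("cast(id as int) as id", "cast(features as array<float>) as features")
    .as[(Int, Array[Float])]
}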


import com.stepstone.search.hnswlib.jna.{Index, SpaceName}
import org.apache.spark.SparkFiles
import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
import java.nio.file.Paths
import scala.reflect.runtime.universe.TypeTag
class annUtilsHnsw {


  /**
   * Builds an HNSW index.
   *
   * The default HNSW parameters are usually good enough.
   *
   * The HNSW index only accepts integer object ids, so the builder re-indexes the original object keys into
   * integer keys.
   *
   * For information on HNSW parameter tuning, see [[https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md]]
   *
   * @param vectorSize     feature vector size
   * @param features       object features to build an index for
   * @param m              HNSW index construction parameter
   * @param efConstruction HNSW index construction parameter
   * @tparam Key type of the object id in the features dataset
   * @return a wrapped [[HnswIndex]] holding the built index and the integer-to-key mapping
   */
  def buildHnswIndex[Key : TypeTag : Encoder](spark: SparkSession, vectorSize: Int,
                                              features: Dataset[(Key, Array[Float])],
                                              m: Int = 100,
                                              efConstruction: Int = 200): HnswIndex[Key] = {

    // map object keys to an integer index, since the HNSW index only accepts integer keys
    import spark.implicits._
    val featuresReindexed = features
      .rdd.zipWithIndex()
      .map { case ((key, vector), idx) => (key, vector, idx.toInt) }
      .toDF("id", "features", "index_id")
      .select("index_id", "id", "features")
      .cache()
    // collect feature vectors
    val featuresList = featuresReindexed
      .select($"index_id", $"features".cast("array<float>"))
      .as[(Int, Array[Float])]
      .collect()

    val objectIDsMap = featuresReindexed
      .select("index_id", "id")
      .as[(Int, Key)]
      .repartition(100)

    // build index
    val index = new Index(SpaceName.COSINE, vectorSize)
    index.initialize(featuresList.length, m, efConstruction, (System.currentTimeMillis() / 1000).toInt)
    println(s"featuresList length: ${featuresList.length}")
    // add vectors in parallel using .par
    featuresList.par.foreach {
      case (id: Int, vector: Array[Float]) =>
        index.addItem(vector, id)
    }

    // return wrapped index
    new HnswIndex(vectorSize, index, objectIDsMap)
  }


}





2 Save and query the index

Save the index, load it and distribute it to every executor, then run the ANN search on each partition. The SparkFiles pattern used for the distribution step is sketched below.
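
The distribution in step 2 relies on Spark's file shipping: the driver registers the saved index file with sparkContext.addFile, and each task resolves its local copy through SparkFiles. A stripped-down sketch of that pattern (spark-shell style; the "index" file name matches the code below):

import org.apache.spark.SparkFiles
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()

// driver side: after index.save(Paths.get("index")), ship the file to every executor
spark.sparkContext.addFile("index")

// executor side (inside a task, e.g. within mapPartitions): locate the shipped copy
val localIndexPath = SparkFiles.get("index")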


import com.stepstone.search.hnswlib.jna.{Index, SpaceName}
import org.apache.spark.SparkFiles
import org.apache.spark.sql.{Dataset, Encoder, SparkSession}
import java.nio.file.Paths
import scala.reflect.runtime.universe.TypeTag

class HnswIndex[DstKey : TypeTag : Encoder](vectorSize: Int,
                                            index: Index,
                                            objectIDsMap: Dataset[(Int, DstKey)]) {


  /**
   * Executes a KNN query using an HNSW index.
   *
   * @param queryFeatures      features to generate recommendations for
   * @param minScoreThreshold  minimum similarity score for a neighbour to be kept
   * @param topK               number of top recommendations to generate per instance
   * @param ef                 HNSW search-time parameter
   * @param queryNumPartitions number of partitions for the query vectors
   * @param indexSavePath      intended save path for the index (unused here; the index is saved to a local "index" path)
   * @param m                  HNSW construction parameter the index was built with, needed to re-initialize it before loading
   * @param efConstruction     HNSW construction parameter the index was built with, needed to re-initialize it before loading
   * @return
   */
  def knnQuery[SrcKey: TypeTag : Encoder](spark: SparkSession, queryFeatures: Dataset[(SrcKey, Array[Float])],
                                          minScoreThreshold: Double,
                                          topK: Int,
                                          ef: Int,
                                          queryNumPartitions: Int = 200,
                                          indexSavePath: String,
                                          m: Int,
                                          efConstruction: Int): Dataset[(SrcKey, DstKey, Double)] = {


    import spark.implicits._
    // number of elements in the index; needed to re-initialize the index on the executors before loading it
    val indexLength = index.getLength

    // save the index locally on the driver
    val saveLocalPath = "index"
    val indexLocalLocation = Paths.get(saveLocalPath)
    val indexFileName = indexLocalLocation.getFileName.toString
    println(s"indexFileName: $indexFileName")
    index.save(indexLocalLocation)

    println(index.getData(0).get().mkString(","))
    println(s"local path: ${indexLocalLocation.toAbsolutePath}")

    // add the file to the Spark context so it is shipped to every executor
    spark.sparkContext.addFile(indexFileName, true)

    println(s"context path: ${SparkFiles.getRootDirectory}/$indexFileName")

    // The current interface to HNSW misses the functionality of setting the ef query time
    // parameter, but it's lower bounded by topK as per https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md#search-parameters,
    // so set a large value of k as max(ef, topK) to get high recall, then cut off after getting the nearest neighbor.
    val k = math.max(topK, ef)

    // copy vectorSize into a local val so the mapPartitions closure does not capture (and try to serialize) this class
    val vectorSizeLocal = vectorSize

    // execute querying
    queryFeatures
      .repartition(queryNumPartitions)
      .toDF("id", "features")
      .withColumn("features", $"features".cast("array<float>"))
      .as[(SrcKey, Array[Float])]
      .mapPartitions((it: Iterator[(SrcKey, Array[Float])]) => {
        // load index
        val index = new Index(SpaceName.COSINE, vectorSizeLocal)

        index.initialize(indexLength, m, efConstruction, (System.currentTimeMillis() / 1000).toInt)
        index.load(Paths.get(SparkFiles.getRootDirectory + "/" + indexFileName), indexLength)


        it.flatMap(x => {
          val idx = x._1
          val vector = x._2
          val queryTuple = index.knnQuery(vector, k)
          val result = queryTuple.getIds
            .zip(queryTuple.getCoefficients)
            .map(qt => (idx, qt._1, 1.0 - qt._2.toDouble)) // convert cosine distance to similarity
            .filter(_._3 >= minScoreThreshold)
            .sortBy(_._3)
            .reverse
            .slice(0, topK)
          result
        })
      })

      .as[(Int, Int, Double)]
      .toDF("src_id", "index_id", "score")
      .join(objectIDsMap.toDF("index_id", "dst_id"), Seq("index_id"))
      .select("src_id", "dst_id", "score")
      .repartition(400)
      .as[(SrcKey, DstKey, Double)]

  }
}

3 word2vec vector search example

  • Train a word2vec model
  • Extract the model's vectors and call the buildHnswIndex method above to build the index
  • Run the distributed knnQuery vector search


import org.apache.spark.ml.feature.Word2VecModel
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql.SparkSession

object exampleWord2Vec {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()

    val GraphInputModel = "graph/model/word2vecmodel"
    val indexPath = "graph/model/index"
    val savePathMl = "graph/model/" // output base path; assumed here, it was not defined in the original snippet
    spark.udf.register("denseVec2Array", (vec: DenseVector) => vec.toArray.map(_.toFloat))
    spark.udf.register("vectorSplit", (a: String) => a.split(',').map(_.toFloat))
    import spark.implicits._
    val word2vec = Word2VecModel.load(GraphInputModel)
    println(word2vec.getVectors.schema)
    word2vec.getVectors.show(10)
    println(word2vec.getVectors.count())
    val itemEmbeddings = word2vec.getVectors
      .selectExpr("cast(word as Int) as word", "denseVec2Array(vector) as features")
      .as[(Int, Array[Float])]
    itemEmbeddings.show()
    println(itemEmbeddings.schema)
    val vectorsize = itemEmbeddings.take(1)(0)._2.length

    // build the index with m = 20 (efConstruction keeps its default of 200)
    val hnswIndex = new annUtilsHnsw().buildHnswIndex(spark, vectorsize, itemEmbeddings, 20)
    // query the 20 nearest neighbours for a small sample of vectors (ef = 200, 160 query partitions)
    val queryDF = hnswIndex.knnQuery[Int](spark, itemEmbeddings.limit(20), 0.3, 20, 200, 160, indexPath, 20, 200)
    queryDF.write.mode("overwrite").save(savePathMl + "graph/muiscEmbedding")
  }

}

4 HNSW pom dependency

hnswlib-jna

<dependency>
    <groupId>com.stepstone.search.hnswlib.jna</groupId>
    <artifactId>hnswlib-jna</artifactId>
    <version>1.4.2</version>
</dependency>

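If you build with sbt instead of Maven, the same artifact (identical coordinates to the pom entry above) can be added with:

libraryDependencies += "com.stepstone.search.hnswlib.jna" % "hnswlib-jna" % "1.4.2"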