适用于 Lucene 的标准（精确）BM25 实现：文档长度按原始词项数存储，不做压缩


import org.apache.lucene.index.FieldInvertState;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.TermStatistics;
import org.apache.lucene.search.similarities.Similarity;

import java.util.ArrayList;
import java.util.List;

/**
 * BM25 similarity whose norm is the raw field length (term count), returned by
 * {@link #computeNorm} without any encoding, so {@code dl} in the formula is the exact length.
 *
 * <p>Per-term score: {@code boost * idf * freq / (freq + k1 * (1 - b + b * dl / avgdl))} with
 * {@code idf = log(1 + (N - n + 0.5) / (n + 0.5))}. For multi-term statistics the idf is the sum
 * of the per-term idfs.
 */
public class AccurateBM25Similarity extends Similarity {
  private final float k1;
  private final float b;

  /**
   * Creates a BM25 similarity with explicit parameters.
   *
   * @param k1 term-frequency saturation; must be finite and non-negative
   * @param b length-normalization strength; must be in [0, 1]
   * @throws IllegalArgumentException if either parameter is out of range
   */
  public AccurateBM25Similarity(float k1, float b) {
    if (!Float.isFinite(k1) || k1 < 0) {
      throw new IllegalArgumentException("illegal k1 value: " + k1 + ", must be a non-negative finite value");
    }
    if (Float.isNaN(b) || b < 0 || b > 1) {
      throw new IllegalArgumentException("illegal b value: " + b + ", must be between 0 and 1");
    }
    this.k1 = k1;
    this.b = b;
  }

  /** Creates a BM25 similarity with {@code k1 = 1.2} and {@code b = 0.75}. */
  public AccurateBM25Similarity() {
    this(1.2f, 0.75f);
  }

  /**
   * Computes the BM25 inverse document frequency {@code log(1 + (N - n + 0.5) / (n + 0.5))}.
   *
   * @param docFreq n, number of documents containing the term
   * @param docCount N, number of documents that have the field
   * @return the idf value
   */
  protected float idf(long docFreq, long docCount) {
    return (float) Math.log(1 + (docCount - docFreq + 0.5D) / (docFreq + 0.5D));
  }

  /** Average field length: total term occurrences in the field divided by documents with it. */
  private float avgFieldLength(CollectionStatistics collectionStats) {
    return (float) (collectionStats.sumTotalTermFreq() / (double) collectionStats.docCount());
  }

  /**
   * Returns the exact field length as the norm. DOCS-only fields in indexes created with major
   * version 8 or later use the number of unique terms; all other fields use the token count.
   */
  @Override
  public final long computeNorm(FieldInvertState state) {
    final int numTerms;
    if (state.getIndexOptions() == IndexOptions.DOCS && state.getIndexCreatedVersionMajor() >= 8) {
      numTerms = state.getUniqueTermCount();
    } else {
      numTerms = state.getLength();
    }
    return numTerms;
  }

  /** Explains the idf of a single term. */
  private Explanation idfExplain(CollectionStatistics collectionStats, TermStatistics termStats) {
    final long df = termStats.docFreq();
    final long docCount = collectionStats.docCount();
    final float idf = idf(df, docCount);
    return Explanation.match(idf, "idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:",
        Explanation.match(df, "n, number of documents containing term"),
        Explanation.match(docCount, "N, total number of documents with field"));
  }

  /** Explains the idf of a multi-term query as the sum of the per-term idfs. */
  private Explanation idfExplain(CollectionStatistics collectionStats, TermStatistics[] termStats) {
    double idf = 0d; // sum into a double before casting into a float
    List<Explanation> details = new ArrayList<>();
    for (final TermStatistics stat : termStats) {
      Explanation idfExplain = idfExplain(collectionStats, stat);
      details.add(idfExplain);
      idf += idfExplain.getValue().floatValue();
    }
    return Explanation.match((float) idf, "idf, sum of:", details);
  }

  @Override
  public final SimScorer scorer(float boost, CollectionStatistics collectionStats, TermStatistics... termStats) {
    Explanation idf = termStats.length == 1 ? idfExplain(collectionStats, termStats[0]) : idfExplain(collectionStats, termStats);
    float avgdl = avgFieldLength(collectionStats);

    return new BM25ScorerFixed(boost, k1, b, idf, avgdl);
  }

  @Override
  public String toString() {
    return "BM25(k1=" + k1 + ",b=" + b + ")";
  }

  /** Returns the term-frequency saturation parameter k1. */
  public final float getK1() {
    return k1;
  }

  /** Returns the length-normalization parameter b. */
  public final float getB() {
    return b;
  }

  /** Scores documents for one term (or term set) using exact, unencoded field lengths. */
  private static class BM25ScorerFixed extends SimScorer {

    private final float boost;
    private final float k1;
    private final float b;
    private final Explanation idf;
    private final float avgdl;
    /** Precomputed k1 * b / avgdl (avgdl clamped away from zero). */
    private final float multK1_b_InvAvgdl;
    /** Precomputed k1 * (1 - b). */
    private final float multK1minusB;

    /**
     * weight (idf * boost)
     */
    private final float weight;

    BM25ScorerFixed(float boost, float k1, float b, Explanation idf, float avgdl) {
      this.boost = boost;
      this.idf = idf;
      this.avgdl = avgdl;
      this.k1 = k1;
      this.b = b;
      this.weight = boost * idf.getValue().floatValue();
      this.multK1minusB = k1 * (1 - b);
      // Normally avgdl should be >= 1, but let's use Math.max to avoid division by zero just in case
      this.multK1_b_InvAvgdl = k1 * b / Math.max(1e-10f, avgdl);
    }

    /**
     * Length normalization {@code k1 * (1 - b + b * dl / avgdl)} for a raw field length. Shared
     * by {@link #score} and {@link #explainTF} so both use the identical float computation.
     */
    private float lengthNorm(long norm) {
      float docLen = norm;
      return multK1minusB + multK1_b_InvAvgdl * docLen;
    }

    /**
     * Returns {@code weight * freq / (freq + lengthNorm)}, evaluated in a rounding-safe form.
     *
     * <p>Computing {@code weight * freq / (freq + n)} directly can round so that a larger freq
     * yields a slightly smaller score, which breaks the SimScorer contract (score must not decrease
     * as freq grows, nor increase as norm grows). The algebraically equal form
     * {@code weight - weight / (1 + freq * (1 / n))} is built only from operations that preserve
     * monotonicity under float rounding. When k1 is 0, {@code 1 / n} is +Infinity and the result
     * is exactly {@code weight}.
     */
    @Override
    public float score(float freq, long norm) {
      float normInverse = 1f / lengthNorm(norm);
      return weight - weight / (1f + freq * normInverse);
    }

    /**
     * Explains the score as boost, idf and tf factors. The top-level value is exactly what
     * {@link #score} returns, so the explanation never disagrees with the actual ranking.
     */
    @Override
    public Explanation explain(Explanation freq, long encodedNorm) {
      List<Explanation> subs = new ArrayList<>(explainConstantFactors());
      Explanation tfExpl = explainTF(freq, encodedNorm);
      subs.add(tfExpl);
      return Explanation.match(score(freq.getValue().floatValue(), encodedNorm),
          "score(freq=" + freq.getValue() + "), product of:", subs);
    }

    /** Explains the saturated, length-normalized term frequency component. */
    private Explanation explainTF(Explanation freq, long norm) {
      List<Explanation> subs = new ArrayList<>();
      subs.add(freq);
      subs.add(Explanation.match(k1, "k1, term saturation parameter"));
      float docLen = norm;
      subs.add(Explanation.match(b, "b, length normalization parameter"));
      subs.add(Explanation.match(docLen, "dl, length of field"));
      subs.add(Explanation.match(avgdl, "avgdl, average length of field"));
      float normValue = lengthNorm(norm);
      return Explanation.match(
          (float) (freq.getValue().floatValue() / (freq.getValue().floatValue() + (double) normValue)),
          "tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:", subs);
    }

    /** Boost (only when not 1) followed by the idf explanation. */
    private List<Explanation> explainConstantFactors() {
      List<Explanation> subs = new ArrayList<>();
      if (boost != 1.0f) {
        subs.add(Explanation.match(boost, "boost"));
      }
      subs.add(idf);
      return subs;
    }
  }
}

用法（在建立索引时为 IndexWriterConfig 设置相似度）：

import org.apache.lucene.search.similarities.BM25Similarity;

//...

SmartChineseAnalyzer analyzer = new SmartChineseAnalyzer();
IndexWriterConfig config = new IndexWriterConfig(analyzer);
config.setOpenMode(OpenMode.CREATE);
// Use the exact-length BM25 variant when requested; otherwise Lucene's BM25Similarity.
config.setSimilarity(args.bm25Accurate ? new AccurateBM25Similarity() : new BM25Similarity());

//...

摘自 Anserini：https://github.com/castorini/anserini/blob/master/src/main/java/io/anserini/search/similarity/AccurateBM25Similarity.java

你可能感兴趣的:(Java,搜索引擎)