Elasticsearch plugin development: a custom payload_score query

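When term weights need to be stored in the index, they have to be stored as payloads, which is what the delimited payload token filter is for.

Source code: https://github.com/limingnihao/elasticsearch-reference/tree/master/Examples
Official documentation: https://www.elastic.co/guide/en/elasticsearch/reference/7.10/analysis-delimited-payload-tokenfilter.html

The indexed text then looks like this, with each token followed by | and its weight: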

the|0 brown|3 fox|4 is|0 quick|10
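
To make the token|weight notation concrete, here is a minimal sketch, using only the JDK, of how such a line could be split and how each weight could be packed into a 4-byte payload. The class and method names are hypothetical, and the big-endian IEEE 754 layout reflects my understanding of how Lucene encodes float payloads; it is not code taken from the plugin.

```java
import java.nio.ByteBuffer;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper, not part of Lucene or Elasticsearch.
public class DelimitedPayloadSketch {

    // Splits the line on whitespace, then splits each token on the last '|'.
    // A repeated term overwrites its earlier weight in the returned map.
    static Map<String, Float> parse(String line) {
        Map<String, Float> weights = new LinkedHashMap<>();
        for (String token : line.trim().split("\\s+")) {
            int sep = token.lastIndexOf('|');
            if (sep < 0) {
                continue; // assumption: tokens without a delimiter carry no payload
            }
            weights.put(token.substring(0, sep), Float.parseFloat(token.substring(sep + 1)));
        }
        return weights;
    }

    // Packs a float into 4 big-endian bytes (ByteBuffer's default byte order).
    static byte[] encodeFloat(float value) {
        return ByteBuffer.allocate(4).putFloat(value).array();
    }

    // Reverses encodeFloat.
    static float decodeFloat(byte[] payload) {
        return ByteBuffer.wrap(payload).getFloat();
    }

    public static void main(String[] args) {
        Map<String, Float> weights = parse("the|0 brown|3 fox|4 is|0 quick|10");
        weights.forEach((term, w) ->
            System.out.println(term + " -> " + decodeFloat(encodeFloat(w))));
    }
}
```

In a real index the token filter does this work at analysis time; the sketch only shows what ends up attached to each term position.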

At query time, making use of the stored values requires Lucene's PayloadScoreQuery or PayloadCheckQuery.

PayloadScoreQuery:

First, take a look at the constructor of Lucene's PayloadScoreQuery:


  /**
   * Creates a new PayloadScoreQuery
   * @param wrappedQuery the query to wrap
   * @param function a PayloadFunction to use to modify the scores
   * @param decoder a PayloadDecoder to convert payloads into float values
   * @param includeSpanScore include both span score and payload score in the scoring algorithm
   */
  public PayloadScoreQuery(SpanQuery wrappedQuery, PayloadFunction function, PayloadDecoder decoder, boolean includeSpanScore) {
    this.wrappedQuery = Objects.requireNonNull(wrappedQuery);
    this.function = Objects.requireNonNull(function);
    this.decoder = Objects.requireNonNull(decoder);
    this.includeSpanScore = includeSpanScore;
  }

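As you can see, four arguments have to be constructed:

  • SpanQuery wrappedQuery. The query that does the matching (recall); it must be a SpanQuery.
  • PayloadFunction function. How the payloads are combined into a score when several terms match, for example max, min, or sum.
  • PayloadDecoder decoder. How the stored value is decoded, as an int or as a float.
  • boolean includeSpanScore. Whether the span score is included in the final score alongside the payload score.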
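
To illustrate what the function and includeSpanScore arguments control, the sketch below mimics the aggregation in plain Java. It is a stand-in, not Lucene code: the enum and method names are hypothetical, and multiplying the span score by the payload score when includeSpanScore is true reflects my reading of Lucene 8's PayloadScoreQuery, so check it against the version you actually use.

```java
import java.util.Locale;

// Hypothetical stand-in for Lucene's payload functions; not Lucene code.
public class PayloadScoreSketch {

    enum Func {
        SUM, MAX, MIN, AVERAGE;

        // Returns null for unknown names so the caller can reject them.
        static Func fromName(String name) {
            if (name == null) {
                return null;
            }
            try {
                return valueOf(name.trim().toUpperCase(Locale.ROOT));
            } catch (IllegalArgumentException e) {
                return null;
            }
        }

        // Combines the decoded payloads of all matched positions in one document.
        float apply(float[] payloads) {
            if (payloads.length == 0) {
                return 0f; // assumption: no payload means a score of 0
            }
            float sum = 0f, max = payloads[0], min = payloads[0];
            for (float p : payloads) {
                sum += p;
                max = Math.max(max, p);
                min = Math.min(min, p);
            }
            switch (this) {
                case SUM: return sum;
                case MAX: return max;
                case MIN: return min;
                default:  return sum / payloads.length; // AVERAGE
            }
        }
    }

    // Assumption: with includeSpanScore the two scores are multiplied.
    static float score(Func func, float[] payloads, float spanScore, boolean includeSpanScore) {
        float payloadScore = func.apply(payloads);
        return includeSpanScore ? spanScore * payloadScore : payloadScore;
    }

    public static void main(String[] args) {
        float[] payloads = {3f, 4f, 10f}; // weights of brown, fox and quick
        System.out.println(score(Func.fromName("sum"), payloads, 0.5f, false));
        System.out.println(score(Func.fromName("max"), payloads, 0.5f, true));
    }
}
```

Returning null for an unknown function name keeps the lookup and the error handling separate, so the caller decides how to report the problem.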

Now for the implementation. Two classes are needed: a plugin and a builder.

PayloadScoreQParserPlugin

This class registers the builder with Elasticsearch so that the payload_score query can be parsed:

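import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;

import java.util.Collections;
import java.util.List;

public class PayloadScoreQParserPlugin extends Plugin implements SearchPlugin {

    @Override
    public List<QuerySpec<?>> getQueries() {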
        return Collections.singletonList(
            new QuerySpec<>(PayloadScoreQueryBuilder.NAME, PayloadScoreQueryBuilder::new, PayloadScoreQueryBuilder::fromXContent)
        );
    }
}

PayloadScoreQueryBuilder

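First, the fromXContent method, which parses the request body.

It mainly parses our custom parameters: query, func, calc (reserved for a later extension that combines weights across terms), and includeSpanScore.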

public static QueryBuilder fromXContent(XContentParser parser) throws IOException {
    String currentFieldName = null;
    XContentParser.Token token;
    QueryBuilder iqb = null;

    String func = null;
    String calc = null;
    boolean includeSpanScore = false;
    // doXContent emits boost and _name, so they have to be accepted here as well.
    float boost = AbstractQueryBuilder.DEFAULT_BOOST;
    String queryName = null;
    while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
        if (token == XContentParser.Token.FIELD_NAME) {
            currentFieldName = parser.currentName();
        } else if (token == XContentParser.Token.START_OBJECT) {
            if (QUERY_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                iqb = parseInnerQueryBuilder(parser);
            } else {
                throw new ParsingException(parser.getTokenLocation(),
                    "[" + NAME + "] query does not support [" + currentFieldName + "]");
            }
        } else if (token.isValue()) {
            if (FUNC_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                func = parser.text();
            } else if (CALC_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                calc = parser.text();
            } else if (INCLUDE_SPAN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                includeSpanScore = parser.booleanValue();
            } else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                boost = parser.floatValue();
            } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                queryName = parser.text();
            } else {
                throw new ParsingException(parser.getTokenLocation(),
                    "[" + NAME + "] query does not support [" + currentFieldName + "]");
            }
        }
    }
    PayloadScoreQueryBuilder builder = new PayloadScoreQueryBuilder(iqb, func, calc, includeSpanScore);
    builder.boost(boost);
    builder.queryName(queryName);
    return builder;
}

Next, the doToQuery method, which builds the PayloadScoreQuery.

Its main job is to construct the four arguments that Lucene's PayloadScoreQuery requires:

protected Query doToQuery(SearchExecutionContext context) throws IOException {
    // Build the inner query; PayloadScoreQuery only accepts a SpanQuery,
    // so reject anything else (including null) with a clear message.
    Query inner = query.toQuery(context);
    if (!(inner instanceof SpanQuery)) {
        throw new IllegalArgumentException("[" + NAME + "] requires a span query, got: " + inner);
    }
    SpanQuery spanQuery = (SpanQuery) inner;

    PayloadFunction payloadFunction = PayloadUtils.getPayloadFunction(this.func);
    if (payloadFunction == null) {
        throw new IllegalArgumentException("Unknown payload function: " + func);
    }
    PayloadDecoder payloadDecoder = PayloadUtils.getPayloadDecoder("float");

    return new PayloadScoreQuery(spanQuery, payloadFunction, payloadDecoder, this.includeSpanScore);
}

Complete code of PayloadScoreQueryBuilder

package org.elasticsearch.plugins.payload;

import org.apache.lucene.queries.payloads.PayloadDecoder;
import org.apache.lucene.queries.payloads.PayloadFunction;
import org.apache.lucene.queries.payloads.PayloadScoreQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.spans.SpanQuery;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.*;

import java.io.IOException;
import java.util.Objects;

public class PayloadScoreQueryBuilder extends AbstractQueryBuilder<PayloadScoreQueryBuilder> {
    public static final String NAME = "payload_score";

    private static final ParseField QUERY_FIELD = new ParseField("query");
    private static final ParseField FUNC_FIELD = new ParseField("func");
    private static final ParseField CALC_FIELD = new ParseField("calc");
    private static final ParseField INCLUDE_SPAN_SCORE_FIELD = new ParseField("includeSpanScore");

    private final QueryBuilder query;
    private final String func;
    private final String calc;
    private final boolean includeSpanScore;

    public PayloadScoreQueryBuilder(QueryBuilder query, String func, String calc, boolean includeSpanScore) {
        this.query = requireValue(query, "[" + NAME + "] requires '" + QUERY_FIELD.getPreferredName() + "' field");
        this.func = func;
        this.calc = calc;
        this.includeSpanScore = includeSpanScore;
    }

    public PayloadScoreQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.query = in.readNamedWriteable(QueryBuilder.class);
        // func and calc may be absent from the request, so read them as optional strings.
        this.func = in.readOptionalString();
        this.calc = in.readOptionalString();
        this.includeSpanScore = in.readBoolean();
    }

    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeNamedWriteable(query);
        // Must mirror the read order above; writeString would fail on a null value.
        out.writeOptionalString(this.func);
        out.writeOptionalString(this.calc);
        out.writeBoolean(this.includeSpanScore);
    }

    @Override
    protected void doXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject(NAME);
        builder.field(QUERY_FIELD.getPreferredName());
        query.toXContent(builder, params);

        builder.field(FUNC_FIELD.getPreferredName(), this.func);
        builder.field(CALC_FIELD.getPreferredName(), this.calc);
        builder.field(INCLUDE_SPAN_SCORE_FIELD.getPreferredName(), this.includeSpanScore);
        printBoostAndQueryName(builder);
        builder.endObject();
    }

    public static QueryBuilder fromXContent(XContentParser parser) throws IOException {
        String currentFieldName = null;
        XContentParser.Token token;
        QueryBuilder iqb = null;

        String func = null;
        String calc = null;
        boolean includeSpanScore = false;
        // doXContent emits boost and _name, so they have to be accepted here as well.
        float boost = AbstractQueryBuilder.DEFAULT_BOOST;
        String queryName = null;
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = parser.currentName();
            } else if (token == XContentParser.Token.START_OBJECT) {
                if (QUERY_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    iqb = parseInnerQueryBuilder(parser);
                } else {
                    throw new ParsingException(parser.getTokenLocation(),
                        "[" + NAME + "] query does not support [" + currentFieldName + "]");
                }
            } else if (token.isValue()) {
                if (FUNC_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    func = parser.text();
                } else if (CALC_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    calc = parser.text();
                } else if (INCLUDE_SPAN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    includeSpanScore = parser.booleanValue();
                } else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    boost = parser.floatValue();
                } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    queryName = parser.text();
                } else {
                    throw new ParsingException(parser.getTokenLocation(),
                        "[" + NAME + "] query does not support [" + currentFieldName + "]");
                }
            }
        }
        PayloadScoreQueryBuilder builder = new PayloadScoreQueryBuilder(iqb, func, calc, includeSpanScore);
        builder.boost(boost);
        builder.queryName(queryName);
        return builder;
    }

    @Override
    protected Query doToQuery(SearchExecutionContext context) throws IOException {
        // Build the inner query; PayloadScoreQuery only accepts a SpanQuery,
        // so reject anything else (including null) with a clear message.
        Query inner = query.toQuery(context);
        if (!(inner instanceof SpanQuery)) {
            throw new IllegalArgumentException("[" + NAME + "] requires a span query, got: " + inner);
        }
        SpanQuery spanQuery = (SpanQuery) inner;

        // Resolve the aggregation function by name.
        PayloadFunction payloadFunction = PayloadUtils.getPayloadFunction(this.func);
        if (payloadFunction == null) {
            throw new IllegalArgumentException("Unknown payload function: " + func);
        }
        // Payloads are always decoded as floats here.
        PayloadDecoder payloadDecoder = PayloadUtils.getPayloadDecoder("float");

        return new PayloadScoreQuery(spanQuery, payloadFunction, payloadDecoder, this.includeSpanScore);
    }

    @Override
    protected boolean doEquals(PayloadScoreQueryBuilder that) {
        return Objects.equals(query, that.query)
            && Objects.equals(func, that.func)
            && Objects.equals(calc, that.calc)
            && Objects.equals(includeSpanScore, that.includeSpanScore);
    }

    @Override
    protected int doHashCode() {
        return Objects.hash(query, func, calc, includeSpanScore);
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

}

Example request:

POST http://127.0.0.1:9200/position/_search
{
    "query": {
        "payload_score": {
            "func": "sum",
            "calc": "sum",
            "includeSpanScore": false,
            "query": {
                "span_or": {
                    "clauses": [
                        {
                            "span_term": {
                                "FIELD": "test"
                            }
                        }
                    ]
                }
            }
        }
    }
}
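
For completeness, the same request could be built from Java with the JDK's java.net.http client (Java 11+). This is only a sketch: the class name is hypothetical, the host, index name and FIELD placeholder are copied from the example above, values are inserted without JSON escaping, and the request is only built here, not sent.

```java
import java.net.URI;
import java.net.http.HttpRequest;

// Hypothetical helper that builds (but does not send) the search request.
public class PayloadScoreRequestSketch {

    // Assembles the JSON body; "calc" is hardcoded to "sum" as in the example.
    // Assumption: field, term and func contain no characters that need escaping.
    static String body(String field, String term, String func) {
        return "{\"query\":{\"payload_score\":{"
            + "\"func\":\"" + func + "\",\"calc\":\"sum\","
            + "\"includeSpanScore\":false,"
            + "\"query\":{\"span_or\":{\"clauses\":[{\"span_term\":{\""
            + field + "\":\"" + term + "\"}}]}}"
            + "}}}";
    }

    // Wraps the body in a POST request against the example index.
    static HttpRequest build(String field, String term, String func) {
        return HttpRequest.newBuilder(URI.create("http://127.0.0.1:9200/position/_search"))
            .header("Content-Type", "application/json")
            .POST(HttpRequest.BodyPublishers.ofString(body(field, term, func)))
            .build();
    }

    public static void main(String[] args) {
        HttpRequest request = build("FIELD", "test", "sum");
        System.out.println(request.method() + " " + request.uri());
        System.out.println(body("FIELD", "test", "sum"));
        // To execute it against a running node:
        // HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString())
    }
}
```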
