Notes on Preserving Order with collect_list in HiveQL

Consider the following Hive table definition:

create table topic_recommend_score (
  category_id int,
  topic_id bigint,
  score double,
  rank int
);

This table is a simplified version of the topic recommendation score table used in our business. category_id is the category ID, topic_id is the topic ID, and score is the recommendation score. rank is the rank of each topic's score within its category, computed with a window function:
row_number() over(partition by t.category_id order by t.score desc)
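
For reference, a minimal sketch of how the table might be populated; topic_score_raw here is a hypothetical source table holding category_id, topic_id, and score:

insert overwrite table topic_recommend_score
select t.category_id,
       t.topic_id,
       t.score,
       row_number() over(partition by t.category_id order by t.score desc)
from topic_score_raw t;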

When serving recommendation results externally, we take the top 1000 topic IDs under each category, join them into a comma-separated string, and after some post-processing load the result into HBase for callers to query. The SQL that builds the string looks like this:

select category_id,
       concat_ws(',',collect_list(cast(topic_id as string)))
from topic_recommend_score
where rank >= 1 and rank <= 1000
group by category_id;

Looks fine, right? It is actually wrong. In the output there are always some category_id values whose lists come back out of order: for example, a batch of top-ranked IDs and a batch of bottom-ranked IDs swap places, so the rank sequence comes out as n-3, n-2, n-1, n, 5, 6, 7, ..., n-4, 1, 2, 3, 4.

The root cause, naturally, lies in MapReduce: once more than one mapper/reducer is started to process the data, the order of the selected rows will almost certainly differ from the original order. Pinning the mapper count to 1 is cumbersome (see my earlier article on Hive tuning) and not realistic anyway, so we take a detour: bring rank into the collected values, sort again, and strip rank back out after the concatenation. Like this:

select category_id,
       regexp_replace(
         concat_ws(',',
           sort_array(
             collect_list(
               concat_ws(':',lpad(cast(rank as string),5,'0'),cast(topic_id as string))
             )
           )
         ),
       '\\d+:','')
from topic_recommend_score
where rank >= 1 and rank <= 1000
group by category_id;
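
As a quick sanity check, the sort_array plus regexp_replace combination can be tried on literal values (self-contained; assumes a Hive version that allows SELECT without FROM):

select regexp_replace(
         concat_ws(',', sort_array(array('00002:200','00003:300','00001:100'))),
         '\\d+:', '');
-- returns: 100,200,300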

Here rank is placed in front of topic_id, separated by a colon, and the sort_array function sorts the result of collect_list (ascending order only). Note in particular that rank must be left-padded with enough zeros, because what is being sorted are strings rather than numbers; without the padding, lexicographic order would produce 1, 10, 11, 12, 13, 2, 3, 4..., which is wrong again. For example:
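
select sort_array(array('2','10','1'));
-- ["1","10","2"]: string comparison puts "10" before "2"
select sort_array(array('00002','00010','00001'));
-- ["00001","00002","00010"]: zero-padding restores numeric order
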
Once the sorted results are joined together, the regexp_replace function strips each colon together with the digits before it, and we are done.

Incidentally, let's take a look at the logic behind the collect_list and collect_set functions in the Hive source code.

// The evaluator lives in Hive's org.apache.hadoop.hive.ql.udf.generic package;
// the package declaration and imports are reproduced here for context.
package org.apache.hadoop.hive.ql.udf.generic;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;

import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;

public class GenericUDAFMkCollectionEvaluator extends GenericUDAFEvaluator
    implements Serializable {
  private static final long serialVersionUID = 1L;

  enum BufferType { SET, LIST }

  // For PARTIAL1 and COMPLETE: ObjectInspectors for original data
  private transient PrimitiveObjectInspector inputOI;
  // For PARTIAL2 and FINAL: ObjectInspectors for partial aggregations (list
  // of objs)
  private transient StandardListObjectInspector loi;

  private transient ListObjectInspector internalMergeOI;

  private BufferType bufferType;

  // no-arg constructor needed by Kryo serialization
  public GenericUDAFMkCollectionEvaluator() {
  }

  public GenericUDAFMkCollectionEvaluator(BufferType bufferType){
    this.bufferType = bufferType;
  }

  @Override
  public ObjectInspector init(Mode m, ObjectInspector[] parameters)
      throws HiveException {
    super.init(m, parameters);
    // init output object inspectors
    // The output of a partial aggregation is a list
    if (m == Mode.PARTIAL1) {
      inputOI = (PrimitiveObjectInspector) parameters[0];
      return ObjectInspectorFactory
          .getStandardListObjectInspector((PrimitiveObjectInspector) ObjectInspectorUtils
              .getStandardObjectInspector(inputOI));
    } else {
      if (!(parameters[0] instanceof ListObjectInspector)) {
        // no map-side aggregation; the input is still the original data
        inputOI = (PrimitiveObjectInspector) ObjectInspectorUtils
            .getStandardObjectInspector(parameters[0]);
        return (StandardListObjectInspector) ObjectInspectorFactory
            .getStandardListObjectInspector(inputOI);
      } else {
        internalMergeOI = (ListObjectInspector) parameters[0];
        inputOI = (PrimitiveObjectInspector) internalMergeOI.getListElementObjectInspector();
        loi = (StandardListObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(internalMergeOI);
        return loi;
      }
    }
  }


  class MkArrayAggregationBuffer extends AbstractAggregationBuffer {

    private Collection container;

    public MkArrayAggregationBuffer() {
      if (bufferType == BufferType.LIST){
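        // collect_list: an ArrayList keeps elements in insertion order (duplicates allowed)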
        container = new ArrayList();
      } else if(bufferType == BufferType.SET){
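        // collect_set: a LinkedHashSet removes duplicates but still keeps insertion order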
        container = new LinkedHashSet();
      } else {
        throw new RuntimeException("Buffer type unknown");
      }
    }
  }

  @Override
  public void reset(AggregationBuffer agg) throws HiveException {
    ((MkArrayAggregationBuffer) agg).container.clear();
  }

  @Override
  public AggregationBuffer getNewAggregationBuffer() throws HiveException {
    MkArrayAggregationBuffer ret = new MkArrayAggregationBuffer();
    return ret;
  }

  // map side: accumulate one row's value into the aggregation buffer
  @Override
  public void iterate(AggregationBuffer agg, Object[] parameters)
      throws HiveException {
    assert (parameters.length == 1);
    Object p = parameters[0];

    if (p != null) {
      MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
      putIntoCollection(p, myagg);
    }
  }

  // map side: emit the buffer contents as a plain list (the partial aggregate)
  @Override
  public Object terminatePartial(AggregationBuffer agg) throws HiveException {
    MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
    List ret = new ArrayList(myagg.container.size());
    ret.addAll(myagg.container);
    return ret;
  }

  // reduce side: fold a partial list from another task into this buffer
  @Override
  public void merge(AggregationBuffer agg, Object partial)
      throws HiveException {
    MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
    List partialResult = (ArrayList) internalMergeOI.getList(partial);
    if (partialResult != null) {
      for(Object i : partialResult) {
        putIntoCollection(i, myagg);
      }
    }
  }

  // final result: copy the accumulated collection into a fresh list
  @Override
  public Object terminate(AggregationBuffer agg) throws HiveException {
    MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
    List ret = new ArrayList(myagg.container.size());
    ret.addAll(myagg.container);
    return ret;
  }

  private void putIntoCollection(Object p, MkArrayAggregationBuffer myagg) {
    Object pCopy = ObjectInspectorUtils.copyToStandardObject(p,  this.inputOI);
    myagg.container.add(pCopy);
  }

  public BufferType getBufferType() {
    return bufferType;
  }

  public void setBufferType(BufferType bufferType) {
    this.bufferType = bufferType;
  }
}
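
As the code shows, the evaluator simply appends each incoming element to its buffer: an ArrayList for collect_list and a LinkedHashSet (which deduplicates but keeps insertion order) for collect_set. So collect_list returns elements in whatever order the rows arrive, and once several mappers' partial lists are merged on the reduce side, that arrival order no longer matches the table order. This is exactly why the rank-based sort above is necessary.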
