tensorflow API使用笔记 Bucketize

tensorflow分桶API,有好几个接口,其中带boundaries的接口C++实现如下:

template 
struct BucketizeFunctor {
  // PRECONDITION: boundaries_vector must be sorted.
  static Status Compute(OpKernelContext* context,
                        const typename TTypes::ConstTensor& input,
                        const std::vector& boundaries_vector,
                        typename TTypes::Tensor& output) {
    const int N = input.size();
    for (int i = 0; i < N; i++) {
      auto first_bigger_it = std::upper_bound(
          boundaries_vector.begin(), boundaries_vector.end(), input(i));
      output(i) = first_bigger_it - boundaries_vector.begin();
    }

    return Status::OK();
  }
};
  • 输入:input tensor和boundaries_vector
  • 输出:output tensor

使用stl的upper_bound算法查找第一个大于输入值的bound,然后返回这个bound的偏移索引。

所以这里,用户需要指定一个boundary,并且boundary是不变的。

你可能感兴趣的:(tensorflow API使用笔记 Bucketize)