Implementing Custom Operators in TensorFlow

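While studying, I had to implement custom operators in TensorFlow, so I am writing down some thoughts from the whole project here. If anything is wrong, please point it out!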

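Ops and kernels are the two most important concepts in the TensorFlow framework: an Op is like a function declaration, and a kernel is like its implementation. Four points are worth noting. First, every Op consists of a registration part and an implementation part. Second, the OpKernel class (./core/framework/op_kernel.h) is the base class of all kernel classes. Third, every kernel class must override the pure virtual function void Compute(OpKernelContext* context) to implement the Op's behavior. Fourth, the attribute definitions and descriptions of all Ops are expressed as protobuf messages (OpDef).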

I. Basic workflow for implementing a custom operator

1. OP registration

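Register the new Op in a C++ file. Registration is independent of the implementation: it specifies the operator's inputs, outputs, and attributes, and the Op name is written in CamelCase.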

/**
* ./tensorflow/core/framework/op.h
* #define REGISTER_OP(name) REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
* #define REGISTER_OP_UNIQ_HELPER(ctr, name) REGISTER_OP_UNIQ(ctr, name)
* #define REGISTER_OP_UNIQ(ctr, name)                                          \
*   static ::tensorflow::register_op::OpDefBuilderReceiver register_op##ctr    \
*       TF_ATTRIBUTE_UNUSED =                                                  \
*           ::tensorflow::register_op::OpDefBuilderWrapper(name)
* In essence, REGISTER_OP creates an OpDefBuilderReceiver object,
* and the Attr, Input, Output, etc. are stored in an OpDefBuilder object.
*/
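#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow;
using shape_inference::InferenceContext;

REGISTER_OP("myFunc")  // attributes declared with .Attr() are read in the kernel through its context
    .Input("in1: int32")
    .Input("in2: int32")
    .Output("out: int32")
    .Attr("Para1: int")
    .Attr("Para2: int")
    .SetShapeFn([](InferenceContext* c) { return Status::OK(); });  // a real shape fn would also call c->set_output(0, ...)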

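This registers an operator named myFunc with inputs in1 and in2 of type int32, an output out of type int32, and attributes Para1 and Para2 of type int; the function passed to SetShapeFn performs shape inference.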

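Attributes can also be given default values at registration time, e.g. .Attr("Para1: int = 1"); the default-value syntax is the one used in the protobuf representation of the final GraphDef.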
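To make the REGISTER_OP expansion quoted in the comment above more concrete, here is a heavily simplified, standalone C++ model of the same pattern: a chainable builder collects the Input/Output/Attr specs, and a static receiver object registers the result during static initialization. All Mini* names are hypothetical stand-ins, not TensorFlow classes; the real OpDefBuilder also parses and validates every spec.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, heavily simplified stand-ins; not TensorFlow's real classes.
struct MiniOpDef {
  std::string name;
  std::vector<std::string> inputs, outputs, attrs;
};

// Global op registry (plays the role of OpRegistry). A function-local static
// avoids static-initialization-order problems.
std::map<std::string, MiniOpDef>& MiniRegistry() {
  static std::map<std::string, MiniOpDef> registry;
  return registry;
}

// Builder: each method records one spec and returns *this for chaining.
class MiniOpDefBuilder {
 public:
  explicit MiniOpDefBuilder(std::string name) { def_.name = std::move(name); }
  MiniOpDefBuilder& Input(std::string spec) {
    def_.inputs.push_back(std::move(spec));
    return *this;
  }
  MiniOpDefBuilder& Output(std::string spec) {
    def_.outputs.push_back(std::move(spec));
    return *this;
  }
  MiniOpDefBuilder& Attr(std::string spec) {
    def_.attrs.push_back(std::move(spec));
    return *this;
  }
  const MiniOpDef& def() const { return def_; }

 private:
  MiniOpDef def_;
};

// Receiver: its implicit converting constructor runs during static
// initialization and stores the finished definition in the registry.
struct MiniOpDefBuilderReceiver {
  MiniOpDefBuilderReceiver(const MiniOpDefBuilder& builder) {
    MiniRegistry()[builder.def().name] = builder.def();
  }
};

// Same three-level expansion as op.h: __COUNTER__ makes each variable name unique.
#define MINI_REGISTER_OP(name) MINI_REGISTER_OP_UNIQ_HELPER(__COUNTER__, name)
#define MINI_REGISTER_OP_UNIQ_HELPER(ctr, name) MINI_REGISTER_OP_UNIQ(ctr, name)
#define MINI_REGISTER_OP_UNIQ(ctr, name) \
  static MiniOpDefBuilderReceiver mini_register_op##ctr = MiniOpDefBuilder(name)

MINI_REGISTER_OP("myFunc")
    .Input("in1: int32")
    .Input("in2: int32")
    .Output("out: int32")
    .Attr("Para1: int")
    .Attr("Para2: int");
```

Because registration happens in the constructor of a namespace-scope static, the op is already in the registry when main() starts, which is why statically linked ops can be looked up by name without any explicit initialization call.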

2. Kernel implementation

/**
* tensorflow/core/framework/op_kernel.h
* class OpKernel {
*   public:
*    explicit OpKernel(OpKernelConstruction* context);
*   
*    OpKernel(OpKernelConstruction* context, bool is_deferred);
*   
*    OpKernel(OpKernelConstruction* context, NodeDef&& custom_def,
*             bool is_deferred);
*    ...
*      TF_DISALLOW_COPY_AND_ASSIGN(OpKernel);
* };
*/
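class myFuncOp: public OpKernel{ // a kernel class derived from OpKernel
    public:
        // Constructor; it must explicitly call OpKernel(context)
        explicit myFuncOp(OpKernelConstruction* context):OpKernel(context)
        {
            // Read the attributes declared in REGISTER_OP
            OP_REQUIRES_OK(context, context->GetAttr("Para1", &para1_));
            OP_REQUIRES_OK(context, context->GetAttr("Para2", &para2_));
        }
        void Compute(OpKernelContext* context) override // override OpKernel::Compute
        {
            // Input tensors (read-only references, no const_cast needed)
            const Tensor& in1 = context->input(0);
            const Tensor& in2 = context->input(1);
            // Create an output and allocate its memory with context->allocate_output()
            Tensor* out = nullptr;
            TensorShape out_shape(...);
            OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &out));
            ...
            // Concrete implementation of the operator's behavior
            ...
        }
    private:
        int para1_;
        int para2_;
};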
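The kernel lifecycle can be illustrated without TensorFlow using mock types: attributes are read once in the constructor, and Compute runs on every execution, reading its inputs and allocating its output. MockConstruction, MockContext, MockOpKernel, and the element-wise semantics out = Para1 * in1 + Para2 * in2 are all hypothetical, chosen only for illustration; a real kernel would report errors through OP_REQUIRES rather than a C++ exception.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for OpKernelConstruction: only provides attributes.
struct MockConstruction {
  std::map<std::string, int> attrs;
  int GetAttr(const std::string& name) const { return attrs.at(name); }
};

// Hypothetical stand-in for OpKernelContext: holds inputs and outputs.
struct MockContext {
  std::vector<std::vector<int32_t>> inputs;
  std::vector<std::vector<int32_t>> outputs;
  // Mirrors the idea of allocate_output(): reserve storage for output `index`.
  std::vector<int32_t>& allocate_output(size_t index, size_t num_elements) {
    if (outputs.size() <= index) outputs.resize(index + 1);
    outputs[index].assign(num_elements, 0);
    return outputs[index];
  }
};

// Simplified OpKernel: the base class declares Compute as pure virtual.
class MockOpKernel {
 public:
  virtual ~MockOpKernel() = default;
  virtual void Compute(MockContext* context) = 0;
};

// Hypothetical semantics for myFunc: out = Para1 * in1 + Para2 * in2 (element-wise).
class MyFuncKernel : public MockOpKernel {
 public:
  explicit MyFuncKernel(const MockConstruction& construction)
      : para1_(construction.GetAttr("Para1")),  // attributes are read once, here
        para2_(construction.GetAttr("Para2")) {}

  void Compute(MockContext* context) override {  // called on every execution
    const std::vector<int32_t>& in1 = context->inputs.at(0);
    const std::vector<int32_t>& in2 = context->inputs.at(1);
    // Stands in for an OP_REQUIRES shape check.
    if (in1.size() != in2.size()) throw std::invalid_argument("shape mismatch");
    std::vector<int32_t>& out = context->allocate_output(0, in1.size());
    for (size_t i = 0; i < in1.size(); ++i) out[i] = para1_ * in1[i] + para2_ * in2[i];
  }

 private:
  int para1_;
  int para2_;
};
```

Keeping attribute parsing in the constructor means it is done once when the kernel is created, not on every step.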

3. Kernel registration for the operator

/**
* #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
*   REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
* #define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_builder, ...) \
*   REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, __VA_ARGS__)
* #define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)        \
*   constexpr bool should_register_##ctr##__flag =                      \
*       SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                        \
*   static ::tensorflow::kernel_factory::OpKernelRegistrar              \
*       registrar__body__##ctr##__object(                               \
*           should_register_##ctr##__flag                               \
*               ? ::tensorflow::register_kernel::kernel_builder.Build() \
*               : nullptr,                                              \
*           #__VA_ARGS__,                                               \
*           [](::tensorflow::OpKernelConstruction* context)             \
*               -> ::tensorflow::OpKernel* {                            \
*             return new __VA_ARGS__(context);                          \
*           });
* In essence, REGISTER_KERNEL_BUILDER creates a uniquely named global static variable of type OpKernelRegistrar
* class OpKernelRegistrar {
*     public:
*     OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
*                       std::unique_ptr<OpKernelFactory> factory) {
*       if (kernel_def != nullptr) {
*         InitInternal(kernel_def, kernel_class_name, std::move(factory));
*       }
*     }
*     OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
*                       OpKernel* (*create_fn)(OpKernelConstruction*)) {
*       if (kernel_def != nullptr) {
*         InitInternal(kernel_def, kernel_class_name,
*                      absl::make_unique<PtrOpKernelFactory>(create_fn));
*       }
*     }
* }
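* Constructing an OpKernelRegistrar takes three arguments, which are wrapped into a KernelRegistration struct and stored as the value in the kernel registry:
* first the KernelDef, second the name of the class that defines the kernel, and third the function that creates the kernel object.
* First, KernelDefBuilder's Build() is called to obtain the corresponding KernelDef;
* then the name of the C++ class used to create this kernel is obtained;
* finally, a factory function is wrapped that receives the incoming OpKernelConstruction*, creates the corresponding kernel object, and returns a pointer to it.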
*/
REGISTER_KERNEL_BUILDER(Name("myFunc").Device(DEVICE_CPU), myFuncOp);
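The registration described in the comment above (a uniquely named static registrar holding a kernel definition, a class name, and a factory) can be modeled in standalone C++. The "op:device" key, Registrar, CreateKernel, and the stripped-down Kernel/Construction types are hypothetical simplifications; TensorFlow's real registry stores a KernelDef proto and also matches type constraints and labels when it selects a kernel.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>

// Hypothetical, simplified model of the kernel registry; not TensorFlow's API.
struct Construction {};  // stands in for OpKernelConstruction

class Kernel {  // stands in for OpKernel
 public:
  virtual ~Kernel() = default;
  virtual std::string Describe() const = 0;
};

using Factory = std::function<Kernel*(Construction*)>;

// Value stored per kernel, mirroring the three pieces listed above.
struct Registration {
  std::string kernel_def;  // here just "op:device"; the real KernelDef is a proto
  std::string class_name;
  Factory factory;
};

std::map<std::string, Registration>& KernelRegistry() {
  static std::map<std::string, Registration> registry;
  return registry;
}

// Constructing a registrar (a namespace-scope static) inserts the registration.
struct Registrar {
  Registrar(const std::string& op, const std::string& device,
            const std::string& class_name, Factory factory) {
    const std::string key = op + ":" + device;
    KernelRegistry()[key] = Registration{key, class_name, std::move(factory)};
  }
};

#define MINI_REGISTER_KERNEL(op, device, cls) \
  MINI_REGISTER_KERNEL_UNIQ_HELPER(__COUNTER__, op, device, cls)
#define MINI_REGISTER_KERNEL_UNIQ_HELPER(ctr, op, device, cls) \
  MINI_REGISTER_KERNEL_UNIQ(ctr, op, device, cls)
#define MINI_REGISTER_KERNEL_UNIQ(ctr, op, device, cls) \
  static Registrar mini_registrar_##ctr(                \
      op, device, #cls,                                 \
      [](Construction* context) -> Kernel* { return new cls(context); })

class myFuncOp : public Kernel {
 public:
  explicit myFuncOp(Construction*) {}
  std::string Describe() const override { return "myFuncOp on CPU"; }
};

MINI_REGISTER_KERNEL("myFunc", "CPU", myFuncOp);

// Runtime side: look up by op name and device, then call the stored factory.
std::unique_ptr<Kernel> CreateKernel(const std::string& op, const std::string& device) {
  auto it = KernelRegistry().find(op + ":" + device);
  if (it == KernelRegistry().end()) return nullptr;
  Construction construction;
  return std::unique_ptr<Kernel>(it->second.factory(&construction));
}
```

Storing a factory instead of a kernel object lets the runtime create a fresh kernel for each graph node that uses the op.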

II. Example (based on Lab 7-1 of Intelligent Computing Systems, 《智能计算系统》)

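Once NMS has been implemented, it has to be integrated into the TF framework and TF has to be rebuilt; the process involves interface wrapping and operator integration.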

1. PluginOP wrapping

Use CNML PluginOP to build a user-friendly CNPlugin interface (this step is already implemented).

//plugin_yolov3_detection_output_op.cc
cnmlStatus_t cnmlCreatePluginYolov3DetectionOutputOp(  // creates the op; declares and initializes its parameters ...
    cnmlBaseOp_t *op,
    cnmlPluginYolov3DetectionOutputOpParam_t param,
    cnmlTensor_t *yolov3_input_tensors,
    cnmlTensor_t *yolov3_output_tensors){...}

cnmlStatus_t cnmlComputePluginYolov3DetectionOutputOpForward(...)  // performs the computation by calling cnmlComputePluginOpForward
{
    ...
    cnmlComputePluginOpForward_V3(...);  // or cnmlComputePluginOpForward_V4(...)
    ...
}

2. Lib-layer wrapping

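This layer directly wraps the CNML and CNPlugin operators, and the wrappers are called by the operator's DLP implementation functions. Its purpose is to isolate high-level calls from the low-level implementation.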

//mlu_lib_ops.cc & mlu_lib_ops.h
tensorflow::Status CreateYolov3DetectionOutputOp(...)
{
    CNML_RETURN_STATUS(cnmlCreatePluginYolov3DetectionOutputOp(op, param, input_tensors, output_tensors));
}

tensorflow::Status ComputeYolov3DetectionOutputOp(...)
{
    ...
    CNML_RETURN_STATUS(cnmlComputePluginYolov3DetectionOutputOpForward(op, inputs, input_num, outputs, output_num, &compute_forw_param, queue));
}

3. DLP implementation of the operator

//mlu_ops.h: operator class declarations
struct MLUYolov3DetectionOutputOpParam{// data member declarations
    ...
    MLUYolov3DetectionOutputOpParam(...): ...{}
};

/**
* Class declaration; the class derives from MLUBaseOpWrapper and declares:
* CreateMLUOp(inputs, outputs, param)
* Compute(const std::vector<void *> &inputs, const std::vector<void *> &outputs, cnrtQueue_t queue) override
*/
DECLARE_OP_CLASS(MLUYolov3DetectionOutput);

//yolov3detectionoutput.cc: implementation
Status MLUYolov3DetectionOutput::CreateMLUOp(std::vector<MLUTensor *> &inputs, std::vector<MLUTensor *> &outputs, void *param){
    // Define the input and output tensors
    ...
    // Initialize the parameters
    ...
    // Call cnmlCreatePluginYolov3DetectionOutputOpParam
    // Call CreateYolov3DetectionOutputOp
    ...
}

Status MLUYolov3DetectionOutput::Compute(const std::vector<void *> &inputs, const std::vector<void *> &outputs, cnrtQueue_t queue)
{
    // Get the variables
    ...
    // Call ComputeYolov3DetectionOutputOp
    ...
}

4. MLU operator instantiation

At runtime, the MLU automatically binds the operator to the runtime queue and dispatches it for execution.

//mlu_stream.h
Status Yolov3DetectionOutput(OpKernelContext* ctx,
                             Tensor* tensor_input0,
                             Tensor* tensor_input1,
                             Tensor* tensor_input2,
                             ...
                             Tensor* output1,
                             Tensor* output2){
    // Instantiate MLUYolov3DetectionOutputOpParam
    ops::MLUYolov3DetectionOutputOpParam op_param(...);
    // Invoke MLUYolov3DetectionOutput through CommonOpImpl, which handles the inputs/outputs and creates the OP
    return CommonOpImpl<ops::MLUYolov3DetectionOutput>(
        ctx,
        {tensor_input0, tensor_input1, tensor_input2},
        {output1, output2},
        static_cast<void*>(&op_param));
}

5. Kernel implementation

//yolov3_detection_output_op_mlu.h
class MLUYolov3DetectionOutputOp: public MLUOpKernel{// a kernel class derived from MLUOpKernel
    public:
        // Constructor; it explicitly calls MLUOpKernel(context)
        explicit MLUYolov3DetectionOutputOp(OpKernelConstruction* context):MLUOpKernel(context){
            // Read the attributes, e.g. batchNum
            OP_REQUIRES_OK(context, context->GetAttr("batchNum", &batchNum_));
            ...
        }
        void ComputeOnMLU(OpKernelContext* context) override {
            ...

            // Take the input tensors out of the context
            Tensor* input0 = const_cast<Tensor*>(&context->input(0));
            Tensor* input1 = const_cast<Tensor*>(&context->input(1));
            Tensor* input2 = const_cast<Tensor*>(&context->input(2));
            ...

            // Create the output with context->allocate_output(); its shape must match shape inference
            Tensor* output = nullptr;
            Tensor buffer;
            TensorShape tf_output_shape {...};
            TensorShape tf_buffer_shape {...};
            OP_REQUIRES_OK(context, context->allocate_output(0, tf_output_shape, &output));
            // The op registers only one output ("predicts"), so the buffer cannot be allocated
            // as output 0 a second time; allocate it as a temporary tensor instead
            OP_REQUIRES_OK(context, context->allocate_temp(input0->dtype(), tf_buffer_shape, &buffer));

            // Call the custom operator
            OP_REQUIRES_OK(context, stream->Yolov3DetectionOutput(...));
        }
    //参数声明
    private:
    int batchNum_;
    int inputNum_;
    int classNum_;
    int maskGroupNum_;
    int maxBoxNum_;
    int netw_;
    int neth_;
    float confidence_thresh_;
    float nms_thresh_;
    std::vector<int> inputWs_;
    std::vector<int> inputHs_;
    std::vector<float> biases_;
};

When doing shape inference, pay attention to the following declaration in cnplugin.h:

//cnplugin.h
/*!
 *  @brief A function.
 *
 *  This function creates PluginYolov3DetectionOutputOp with proper param,
 *  input, and output tensors.
 *
 *  PluginYolov3DetectionOutputOp takes in feature maps and network
 *  parameters and computes valid bounding boxes based on two thresholds
 *  you have chosen.
 *
 *  **Reference:**
 *    This implementation is based on the project on ``github/pjreddie/darknet`` .
 *
 *  **Formula:** This op contains two steps:
 *
 *    1. DecodeAllBBoxes.
 *
 *       Convert input feature maps into real objectness score and coordinates.
 *    for inputIdx in (0, inputNum - 1)
 *
 *       obj = sigmoid(obj_feature);
 *       x   = (x_offset + sigmoid(x_feature)) / inputWs[inputIdx]
 *       y   = (y_offset + sigmoid(y_feature)) / inputHs[inputIdx]
 *       w   = (w_biases * exp(w_feature)) / netw
 *       h   = (h_biases * exp(h_feature)) / neth
 *       obj_feature, x_feature, y_feature, w_feature, h_feature are data from input feature maps.
 *       x_offset, y_offset are the coordinates of the grid cell in the feature map.
 *       w_biases, h_biases are the shape of the anchor box.
 *
 *    2. Non-maximum Suppression
 *       For each class of data, compute IOU score for every pair of bounding boxes.
 *       If IOU score exceeds the IOU threshold, keep the box with larger score.
 *       x1 = x - w / 2
 *       y1 = y - h / 2
 *       x2 = x + w / 2
 *       y2 = y + h / 2
 *       for classIdx in (0, classNum - 1)
 *        conf = obj * probability[classIdx]
 *        max, maxIdx = findMaxValueAndIndex(conf)
 *        if (max >= confidence_thresh)
 *          for boxIdx in (0, boxNum - 1)
 *            iou = computeIOU(coord_maxIdx, coord_boxIdx)  // where "coords" means x1,y1,x2,y2
 *            if (iou < nms_thresh)
 *              keep coords and conf for boxIdx
 *
 *  **DataType:**
 *    Support only half(float16) type for both input and output tensors.
 *
 *  **Performance Optimization:**
 *    The performance of detection layer depends on both the data size and the value.
 *    However, this op achieves relatively better performance when
 *    all of the following conditions are met:
 *    - inputH/Ws are 64-aligned(unit in number of data).
 *    - (5 + classNum) is 64-aligned(unit in number of data).
 *    The bigger the remainder of the value of param divided by 64, the better performance the op will achieve.
 *  Supports both MLU220 and MLU270.
 *
 *  @param[out]  op
 *    Output. A pointer to the base operator address.
 *  @param[in]  param
 *    Input. A PluginYolov3DetectionOutput parameter struct pointer.
 *  @param[in]  yolov3_input_tensors
 *    Input. An array of four-dimensional cnmlTensors with a shape of
 *           [batchNum, (5 + classNum) * numMaskGroup, inputH, inputW](NCHW).
 *           Support only FLOAT16 dataType currently.
 *  @param[out]  yolov3_output_tensors
 *    Output. An array of four-dimensional cnmlTensors with a shape of
 *           [batchNum, 64 + 7 * numMaxBox, 1, 1](NCHW).
 *           Support only FLOAT16 dataType currently.
 *           The first two numbers of each batch store the number of
 *           detected boxes. The data for each box starts from the 65th number,
 *           with an order of [batchId, classId, score, x1, y1, x2, y2], where
 *           (x1, y1) and (x2, y2) are the coordinates of the top-left and
 *           bottom-right points, respectively.
 *  @retval CNML_STATUS_SUCCESS
 *    The function ends normally
 *  @retval CNML_STATUS_INVALIDPARAM
 *    At least one of the following conditions is not met:
 *    - Base op pointer is nullptr
 *    - Param is nullptr or not initialized
 *    - Input / output tensor desps is nullptr or inconsistent with param.
 */
cnmlStatus_t cnmlCreatePluginYolov3DetectionOutputOp(
    cnmlBaseOp_t *op,
    cnmlPluginYolov3DetectionOutputOpParam_t param,
    cnmlTensor_t *yolov3_input_tensors,
    cnmlTensor_t *yolov3_output_tensors);
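The two steps in the Formula section above can be sketched in plain C++ on the host, for example to check MLU results against a CPU reference. This is a float32 sketch under stated assumptions: class_prob is taken as an already computed class probability, and the NMS is the usual greedy per-class form (sort by score, drop a box whose IoU with an already kept box reaches nms_thresh), which matches the header's rule of keeping boxes with iou < nms_thresh. It is not Cambricon's implementation, and DecodeBox, IoU, and Nms are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>

// Float sketch of the documented formulas (the real op uses float16 on the MLU).
struct Box {
  float x1, y1, x2, y2;  // corners, normalized to [0, 1]
  float score;           // obj * class probability
};

float Sigmoid(float v) { return 1.0f / (1.0f + std::exp(-v)); }

// Step 1: decode one anchor of one grid cell.
// cx, cy: grid-cell coordinates; bw, bh: anchor size in network-input pixels.
Box DecodeBox(float tx, float ty, float tw, float th, float obj_feature,
              float class_prob, int cx, int cy, int inputW, int inputH,
              float bw, float bh, int netw, int neth) {
  const float x = (cx + Sigmoid(tx)) / inputW;
  const float y = (cy + Sigmoid(ty)) / inputH;
  const float w = bw * std::exp(tw) / netw;
  const float h = bh * std::exp(th) / neth;
  const float obj = Sigmoid(obj_feature);
  return Box{x - w / 2, y - h / 2, x + w / 2, y + h / 2, obj * class_prob};
}

float IoU(const Box& a, const Box& b) {
  const float iw = std::max(0.0f, std::min(a.x2, b.x2) - std::max(a.x1, b.x1));
  const float ih = std::max(0.0f, std::min(a.y2, b.y2) - std::max(a.y1, b.y1));
  const float inter = iw * ih;
  const float uni = (a.x2 - a.x1) * (a.y2 - a.y1) + (b.x2 - b.x1) * (b.y2 - b.y1) - inter;
  return uni > 0.0f ? inter / uni : 0.0f;
}

// Step 2: greedy NMS for one class; returns indices of kept boxes, best first.
std::vector<int> Nms(const std::vector<Box>& boxes, float confidence_thresh, float nms_thresh) {
  std::vector<int> order(boxes.size());
  std::iota(order.begin(), order.end(), 0);
  std::sort(order.begin(), order.end(),
            [&](int a, int b) { return boxes[a].score > boxes[b].score; });
  std::vector<int> kept;
  for (int idx : order) {
    if (boxes[idx].score < confidence_thresh) break;  // sorted, so the rest are lower
    bool suppressed = false;
    for (int k : kept) {
      if (IoU(boxes[k], boxes[idx]) >= nms_thresh) {
        suppressed = true;
        break;
      }
    }
    if (!suppressed) kept.push_back(idx);
  }
  return kept;
}
```

Working in normalized coordinates, as the header does, makes boxes from the three feature maps directly comparable before NMS.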

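The declaration of cnmlCreatePluginYolov3DetectionOutputOp fixes the shape of the output tensor as [batchNum, 64 + 7 * numMaxBox, 1, 1], so the output shape allocated in the kernel and produced by the registered shape function must be consistent with it.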
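Given the fixed output shape [batchNum, 64 + 7 * numMaxBox, 1, 1], reading a detection from the flattened output is plain index arithmetic. The sketch below assumes the output has already been copied back to the host and converted to float; it skips the 64-number per-batch header (where, per the header comment, the box count is stored) without decoding that count. PerBatchSize and ReadBox are hypothetical helpers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Layout documented in cnplugin.h: per batch, 64 header numbers, then
// maxBoxNum records of 7 numbers: [batchId, classId, score, x1, y1, x2, y2].
constexpr int kHeaderSize = 64;
constexpr int kBoxFields = 7;

int PerBatchSize(int maxBoxNum) { return kHeaderSize + kBoxFields * maxBoxNum; }

struct DetectedBox {
  float batchId, classId, score, x1, y1, x2, y2;
};

// Reads box k of batch `batch` from the flattened (float) output buffer.
DetectedBox ReadBox(const std::vector<float>& output, int maxBoxNum, int batch, int k) {
  const std::size_t base = static_cast<std::size_t>(batch) * PerBatchSize(maxBoxNum) +
                           kHeaderSize + static_cast<std::size_t>(kBoxFields) * k;
  return DetectedBox{output[base],     output[base + 1], output[base + 2],
                     output[base + 3], output[base + 4], output[base + 5],
                     output[base + 6]};
}
```

With the default maxBoxNum of 1024, each batch therefore holds 64 + 7 × 1024 = 7232 numbers.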

6. Registration

//yolov3_detection_output_op.cc: kernel registration
REGISTER_KERNEL_BUILDER(
    Name("Yolov3DetectionOutput")
        .Device(DEVICE_MLU)
        .TypeConstraint<Eigen::half>("T"),  // only float16 is supported (see cnplugin.h)
    MLUYolov3DetectionOutputOp);

//image_ops.cc: OP registration
REGISTER_OP("Yolov3DetectionOutput")
    .Output("predicts: T")
    .Input("input0: T")
    .Input("input1: T")
    .Input("input2: T")
    .Attr("batchNum:int")
    .Attr("inputNum:int")
    .Attr("classNum:int")
    .Attr("maskGroupNum:int")
    .Attr("maxBoxNum:int")
    .Attr("netw:int")
    .Attr("neth:int")
    .Attr("confidence_thresh:float")
    .Attr("nms_thresh:float")
    .Attr("inputWs: list(int)")
    .Attr("inputHs: list(int)")
    .Attr("biases: list(float)")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      return SetOutputForYolov3DetectionOutput(c);
    });

The inputs, outputs, and attributes declared at OP registration correspond one-to-one to those of the node in the .pbtxt graph file.

//./cnplugin.h
/*!
 *  @brief A function.
 *  This function creates a PluginYolov3DetectionOutputOp param object with
 *  the pointer and parameters provided by user.
 *  **Supports MLU220/MLU270**
 *  @param[out] param
 *    Output. The returning param descriptor.
 *  @param[in] batchNum
 *    Input. The number of input batches.
 *           No default value, a valid batchNum must be in the range of [1, inf).
 *  @param[in] inputNum
 *    Input. The number of input tensors.
 *           No default value, a valid inputNum must be in the range of [1, 7].
 *  @param[in] classNum
 *    Input. The number of input classes.
 *           No default value, a valid classNum must be in the range of [1, 4096].
 *  @param[in] maskGroupNum
 *    Input. The number of anchors used by every input tensor.
 *           No default value, a valid maskGroupNum must be in the range of [1, inf].
 *  @param[in] maxBoxNum
 *    Input. The largest possible number of output boxes.
 *           Default value is 1024, a valid maxBoxNum must be in the range of [1, inf].
 *  @param[in] netw
 *    Input. Width of input image of backbone network.
 *           No default value, a valid netw must be in the range of [1, inf).
 *  @param[in] neth
 *    Input. Height of input image of backbone network.
 *           No default value, a valid neth must be in the range of [1, inf).
 *  @param[in] confidence_thresh
 *    Input. Confidence threshold.
 *           No default value, a valid confidence_thresh must be in the range of [0, 1].
 *  @param[in] nms_thresh
 *    Input. IOU threshold used in NMS function.
 *           No default value, a valid nms_thresh must be in the range of [0, 1].
 *  @param[in] core_version
 *    Input. Supported core version.
 *           No default value, a valid core_version must be either MLU220 or MLU270.
 *  @param[in] inputWs
 *    Input. Width of every input tensor. Must have the same order as inputHs
 *           No default value, the number of valid elements must be equal to inputNum.
 *  @param[in] inputHs
 *    Input. Height of every input tensor. Must have the same order as inputWs
 *           No default value, the number of valid elements must be equal to inputNum.
 *  @param[in] biases
 *    Input. Anchors of every input tensor.
 *           No default value. The number of valid elements must be equal to 2 x inputNum x maskGroupNum.
 *           The order of data from high to low, is [N(1) H(inputNum) W(maskGroupNum) C(2)]. For example:
 *           Width of anchor for mask0 input0, Height of anchor for mask0 input0,
 *           Width of anchor for mask1 input0, Height of anchor for mask1 input0,
 *           ...
 *           Width of anchor for maskN input0, Height of anchor for maskN input0,
 *           Width of anchor for mask0 input1, Height of anchor for mask0 input1,
 *           ......
 *  @retval CNML_STATUS_SUCCESS
 *    The object was set successfully.
 *  @retval CNML_STATUS_INVALIDPARAM
 *    The inputH/Ws ptr is nullptr or input param is invalid.
 */
cnmlStatus_t cnmlCreatePluginYolov3DetectionOutputOpParam(
    cnmlPluginYolov3DetectionOutputOpParam_t *param,
    int batchNum,
    int inputNum,
    int classNum,
    int maskGroupNum,
    int maxBoxNum,
    int netw,
    int neth,
    float confidence_thresh,
    float nms_thresh,
    cnmlCoreVersion_t core_version,
    int *inputWs,
    int *inputHs,
    float *biases);

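cnmlCreatePluginYolov3DetectionOutputOpParam is declared in ./cnplugin.h, and its comments explain the meaning of each parameter. Every parameter involved must be given a value, either as a default at OP registration or when the node is added.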

The parameter values are determined by the dataset and the characteristics of the algorithm:

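① COCO has 80 classes, and all original images are resized to 416 × 416 (classNum = 80, netw = neth = 416);

② YOLOv3 performs detection at three scales: 13 × 13, 26 × 26, and 52 × 52 (inputNum = 3);

③ At each scale, every grid cell uses 3 anchors to predict 3 bounding boxes, so there are 9 anchors in total, 3 per scale (maskGroupNum = 3);

④ The nine anchors used for detection are (10 × 13), (16 × 30), (33 × 23), (30 × 61), (62 × 45), (59 × 119), (116 × 90), (156 × 198), and (373 × 326), each given as w × h; in the parameters they are arranged from large to small.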
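Putting the four points together, the attribute values for YOLOv3 on COCO could be collected as below. The assumption here is that the inputs are ordered 13 × 13, 26 × 26, 52 × 52, so the coarsest map, which uses the three largest anchors, comes first in biases and the groups run from large to small; the index helper follows the [N(1) H(inputNum) W(maskGroupNum) C(2)] order documented in cnplugin.h. Yolov3Params, BiasIndex, and InputChannels are hypothetical names.

```cpp
#include <cassert>
#include <vector>

// Hypothetical container for the YOLOv3-on-COCO attribute values listed above.
struct Yolov3Params {
  int classNum = 80;
  int netw = 416, neth = 416;
  int inputNum = 3;
  int maskGroupNum = 3;  // anchors per scale
  // Assumption: input0 is the 13x13 map, input1 26x26, input2 52x52.
  std::vector<int> inputWs{13, 26, 52};
  std::vector<int> inputHs{13, 26, 52};
  // (w, h) pairs grouped per input, largest anchors first.
  std::vector<float> biases{116, 90, 156, 198, 373, 326,  // input0: 13x13
                            30,  61, 62,  45,  59,  119,  // input1: 26x26
                            10,  13, 16,  30,  33,  23};  // input2: 52x52
};

// Index into `biases` for (input, mask, component), component 0 = width, 1 = height.
int BiasIndex(const Yolov3Params& p, int inputIdx, int maskIdx, int c) {
  return (inputIdx * p.maskGroupNum + maskIdx) * 2 + c;
}

// Channel count of each input feature map: (5 + classNum) * maskGroupNum.
int InputChannels(const Yolov3Params& p) { return (5 + p.classNum) * p.maskGroupNum; }
```

For COCO this gives (5 + 80) × 3 = 255 channels per input feature map, matching the [batchNum, (5 + classNum) * numMaskGroup, inputH, inputW] input shape in cnplugin.h.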

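III. TensorFlow source directories involved in custom development

tensorflow/core:

----kernels: concrete kernel implementations

----ops: OP registration and declaration

tensorflow/stream_executor:

The runtime environment, which manages how TF executes work on high-performance parallel devices (restricting which tasks may run concurrently, specifying the dependencies between tasks, ...)

----mlu: the submodule used by the MLU execution engine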

IV. Thoughts

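Custom operator development consists of two parts: registration and implementation. When only the CPU is targeted, the kernel can be implemented right after the OP is registered, and the kernel is registered last. To run the operator on a more efficient device, device-specific development is needed, which can be divided into the following steps:

1. Implement the operator's kernel function for the device;

2. Wrap the device-provided interfaces into a user-friendly interface;

3. Wrap the interface from step 2 further, to isolate high-level calls from the low-level implementation;

4. Implement the device-side operator (essentially a call to the interface from step 3);

5. Instantiate the device-side operator; at runtime it is automatically bound to the runtime queue and dispatched for execution;

6. Implement the OpKernel: create and initialize the input and output tensors, and call the operator instantiated in step 5;

7. Register the OP and the kernel.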

V. References

AICS Lab 7-1 walkthrough: https://blog.csdn.net/weixin_40943865/article/details/122059436

StreamExecutor: https://blog.csdn.net/qq_36178899/article/details/84521479

Op shape inference: https://blog.csdn.net/HaoBBNuanMM/article/details/115352223

Implementing a custom Op: https://docs.pythontab.com/tensorflow/how_tos/adding_an_op/#op-kernel
