<小米开源框架MACE> 自定义新的Op

如果遇到框架不支持的Op,用户可以自行创建。MACE提供了非常友好的创建自定义Op的操作,只需按照以下步骤执行。

1. 定义Op类

创建新文件mace/ops/my_custom_op.h,并在文件中定义新的类,例如MyCustomOp,代码如下

#ifndef MACE_OPS_MY_CUSTOM_OP_H_
#define MACE_OPS_MY_CUSTOM_OP_H_

#include "mace/core/operator.h"
#include "mace/kernels/my_custom_op.h"

namespace mace {
namespace ops {

template 
class MyCustomOp : public Operator {
  public:
    MyCustomOp(const OperatorDef &op_def, Workspace *ws)
      : Operator(op_def, ws),
        functor_() {}

    bool Run(StatsFuture *future) override {
      const Tensor *input = this->Input(INPUT);
      Tensor *output = this->Output(OUTPUT);

      functor_(input, output, future);
      return true;
    }

  protected:
    OP_INPUT_TAGS(INPUT);
    OP_OUTPUT_TAGS(OUTPUT);

  private:
    kernels::MyCustomOpFunctor functor_;
};

} // namespace ops
} // namespace mace

#endif // MACE_OPS_MY_CUSTOM_OP_H_
2. 注册Op类

创建新文件 mace/ops/my_custom_op.cc,并在其中注册新建的类,代码如下。从代码里可以看出,注册代码提供了CPU和GPU的两种实现,其中GPU版本包含floathalf的两种数据支持。

除此之外,还需要在mace/core/operator.cc文件的namespace ops下按模板添加注册代码extern void Register_My_Custom_Op(OperatorRegistry *op_registry);,同时在OperatorRegistry()构造函数中添加ops::Register_My_Custom_Op(this);

#include "mace/ops/my_custom_op.h"

namespace mace {
namespace ops {

void Register_My_Custom_Op(OperatorRegistry *op_registry) {
  REGISTER_OPERATOR(op_registry, OpKeyBuilder("my_custom_op")
                                    .Device(DeviceType::CPU)
                                    .TypeConstraint("T")
                                    .Build(),
                      Custom_Op);

  REGISTER_OPERATOR(op_registry, OpKeyBuilder("my_custom_op")
                                    .Device(DeviceType::OPENCL)
                                    .TypeConstraint("T")
                                    .Build(),
                      Custom_Op);

  REGISTER_OPERATOR(op_registry, OpKeyBuilder("my_custom_op")
                                    .Device(DeviceType::OPENCL)
                                    .TypeConstraint("T")
                                    .Build(),
                      Custom_Op);
}

} // namespace ops
} // namespace mace
3. 实现Op的核心代码

注册完成Op后,需要实现其核心代码。创建文件mace/kernels/my_custom_op.h,这个文件文件实现的是CPU版本的代码。也可以选择实现OpenCL版本代码,这时,需要创建两个文件:mace/kernels/opencl/my_custom_op_opencl.ccmace/kernels/opencl/cl/ my_custom_op.cl。也可以对CPU版本的实现进行NEON指令集优化。

4. 添加Op的测试代码
5. 添加Op的文档信息

你可能感兴趣的:(<小米开源框架MACE> 自定义新的Op)