【TVM系列五】添加Relay自定义算子

一、前言

本文以实现一个axis_abs的自定义算子为例介绍如何在tvm中添加新的relay算子,该算子实现的功能是以输入的3维tensor取某一维度的指定切片取绝对值。

二、添加自定义算子

新增relay算子基本是下面几个步骤:

  • 定义新增算子的属性节点(Attribute Node),声明在编译时已知的固定参数;

  • 为新增算子编写类型关系,以集成到relay的类型系统中;

  • 使用C++RELAY_REGISTER_OP宏,为新增算子注册生命参数数量、类型、提示信息;

  • 算子的compute实现;

  • 注册算子的compute、schedule;

  • 定义C++函数,为新增算子生成调用节点,并为该函数注册 Python API hook;

  • 将上面的 Python API hook 封装成简洁的调用方式;

  • 为新的relay 算子编写测试。

1、定义新增算子的属性节点(Attribute Node)

在include/tvm/relay/attrs/transform.h中增加算子的属性数据结构:

/*! \brief Attributes used in axisabs operator */
struct AxisAbsAttrs : public tvm::AttrsNode {
    int axis;
    int indice;

    TVM_DECLARE_ATTRS(AxisAbsAttrs, "relay.attrs.AxisAbsAttrs") {
        TVM_ATTR_FIELD(axis).set_default(0).describe("Axis to abs");
        TVM_ATTR_FIELD(indice).set_default(0).describe("Indice to abs");
    }
};

Q:宏TVM_DECLARE_ATTRS 与 TVM_ATTR_FIELD的作用是什么?
A:这两个宏定义在 include/tvm/ir/attrs.h

#define TVM_DECLARE_ATTRS(ClassName, TypeKey)                    \
  static constexpr const char* _type_key = TypeKey;              \
  TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
  template                                      \
  void __VisitAttrs__(FVisit& __fvisit__)  // NOLINT(*)

#define TVM_ATTR_FIELD(FieldName) __fvisit__(#FieldName, &FieldName)

其中的TVM_DECLARE_FINAL_OBJECT_INFO定义在include/tvm/runtime/object.h

#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
   static const constexpr bool _type_final = true;           \
   static const constexpr int _type_child_slots = 0;         \
   TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
  
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)                                     \
   static_assert(!ParentType::_type_final, "ParentObj marked as final");                        \
   static uint32_t RuntimeTypeIndex() {                                                         \
     static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 ||    \
                      TypeName::_type_child_slots < ParentType::_type_child_slots,             \
                  "Need to set _type_child_slots when parent specifies it.");                  \
     if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {                        \
       return TypeName::_type_index;                                                            \
     }                                                                                          \
     return _GetOrAllocRuntimeTypeIndex();                                                      \
   }                                                                                            \
  static uint32_t _GetOrAllocRuntimeTypeIndex() {                                              \
    static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex(                               \
        TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \
        TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow);                \
    return tindex;                                                                             \
  }

所以宏展开后定义的属性节点数据结构为:

struct AxisAbsAttrs : public tvm::ArrayNode {
    int axis;    
    static constexpr const char* _type_key = "relay.attrs.AxisAbsAttrs";
    static const constexpr bool _type_final = true;
    static const constexpr int _type_child_slots = 0;

    static_assert(!::tvm::BaseAttrsNode::_type_final, "ParentObj marked as final");

    static uint32_t RuntimeTypeIndex() {                                                       
        static_assert(AxisAbsAttrs::_type_child_slots == 0 || ::tvm::BaseAttrsNode::_type_child_slots == 0 ||    
                          AxisAbsAttrs::_type_child_slots < ::tvm::BaseAttrsNode::_type_child_slots,             
                      "Need to set _type_child_slots when parent specifies it.");                  
        if (AxisAbsAttrs::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {                        
            return AxisAbsAttrs::_type_index;                                                            
        }                                                                                          
         return _GetOrAllocRuntimeTypeIndex();                                                      
      }           

    static uint32_t _GetOrAllocRuntimeTypeIndex() {                                              
         static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex(                               
         AxisAbsAttrs::_type_key, AxisAbsAttrs::_type_index, ::tvm::BaseAttrsNode::_GetOrAllocRuntimeTypeIndex(), 
         AxisAbsAttrs::_type_child_slots, AxisAbsAttrs::_type_child_slots_can_overflow);                
         return tindex;                                                                             
    }

    template                                     
    void __VisitAttrs__(FVisit& __fvisit__)  {
        __fvisit__(axis, &axis).set_default(0).describe("Axis to abs");
    }
}

可以看到,每个属性节点都定义了获取运行时类型索引的函数RuntimeTypeIndex()以及访问属性内部成员的模版函数VisitAttrs(FVisit& fvisit)。

Q:模版函数VisitAttrs(FVisit& fvisit)的调用过程是怎么样的?
A:首先分析定义在include/tvm/ir/attrs.h中的类class AttrsNode

template 
class AttrsNode : public BaseAttrsNode {
public:
  void VisitAttrs(AttrVisitor* v) {
    ::tvm::detail::AttrNormalVisitor vis(v);
    self()->__VisitAttrs__(vis);
  }
  void VisitNonDefaultAttrs(AttrVisitor* v) {...}
  void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {...}
  bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {...}
  void SHashReduce(SHashReducer hash_reducer) const {...}
  Array ListFieldInfo() const final {...}
private:
  DerivedType* self() const {
    return const_cast(static_cast(this));
  }
};

它是一个模版类,模版参数是继承它的子类类型,在成员函数VisitAttrs(AttrVisitor* v)中,传入属性访问器类AttrVisitor对象:

class AttrVisitor {
 public:
  //! \cond Doxygen_Suppress
  TVM_DLL virtual ~AttrVisitor() = default;
  TVM_DLL virtual void Visit(const char* key, double* value) = 0;
  TVM_DLL virtual void Visit(const char* key, int64_t* value) = 0;
  TVM_DLL virtual void Visit(const char* key, uint64_t* value) = 0;
  TVM_DLL virtual void Visit(const char* key, int* value) = 0;
  TVM_DLL virtual void Visit(const char* key, bool* value) = 0;
  TVM_DLL virtual void Visit(const char* key, std::string* value) = 0;
  TVM_DLL virtual void Visit(const char* key, void** value) = 0;
  TVM_DLL virtual void Visit(const char* key, DataType* value) = 0;
  TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0;
  TVM_DLL virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
  template ::value>::type>
  void Visit(const char* key, ENum* ptr) {
    static_assert(std::is_same::type>::value,
                  "declare enum to be enum int to use visitor");
    this->Visit(key, reinterpret_cast(ptr));
  }
  //! \endcond
};

然后通过::tvm::detail::AttrNormalVisitor vis(v);包裹一层普通属性访问函数:

// Wrapper for normal visitor.
class AttrNormalVisitor {
public:
  explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
  template 
  AttrNopEntry operator()(const char* key, T* value) {
    visitor_->Visit(key, value);
    return AttrNopEntry();
  }

private:
  AttrVisitor* visitor_;
};

它重载了运算符“()”,当class AttrsNode通过self()->VisitAttrs(vis)获取子类的对象并通过子类对象调用VisitAttrs(FVisit& fvisit) 时,随即调用了fvisit(axis, &axis),这个fvisit最终调到的就是class AttrNormalVisitor 中的重载"()"函数,这个函数会返回一个结构体用于支持链式调用:

// helper entry that does nothing in set_default/bound/describe calls.
struct AttrNopEntry {
  using TSelf = AttrNopEntry;
  TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
  template 
  TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {return *this;}
  template 
  TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {return *this;}
  template 
  TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {return *this;}
};

这些调用实际上什么都没有做就返回了其自身。

2、编写算子类型关系,集成到Relay的类型系统

为了算子注册的灵活性以及relay算子有更好的泛化能力,relay算子通过输入输出之间的类型关系来实例化。本质上,算子类型关系除了推导输出类型外,还能够强制指定类型规则(检查输入类型)。需要在src\relay\op\tensor\transform.cc中添加算子的类型关系处理函数:

bool AxisAbsRel(const Array& types, int num_inputs, const Attrs& attrs,
               const TypeReporter& reporter) {
    // types: [data, output]
    ICHECK_EQ(types.size(), 2);
    const auto* data = types[0].as();
    if (data == nullptr) {
      ICHECK(types[0].as())
          << "cast: expect input type to be TensorType but get " << types[0];
      return false;
    }
    const auto* param = attrs.as();
    const int ndim = static_cast(data->shape.size());
    const int axis = param->axis;
    const int axis_len = data->shape[axis].as()->value;
    const int indice = param->indice;

    ICHECK(0 <= axis && axis < ndim)
      << "axis_abs only accepts `axis` in [0, data.ndim - 1]"
      << ", but got axis = " << axis << ", and data.ndim = " << ndim;

    ICHECK(0 <= indice && indice < axis_len)
      << "axis_abs only accepts `indice` in [0, data[axis] - 1"
      << ", but got indice = " << indice << ", and data[axis] = " << axis_len;

    reporter->Assign(types[1], TensorType(data->shape, data->dtype));
    return true;
}

Q:类型关系处理函数在什么时候调用?
A:类型关系处理函数在注册Relay算子时通过链式调用add_type_rel()注册。

Q:函数输入参数types的含意是什么?
A:types传入的是一个数组引用,内容一般为输入与输出的TensorType,首先看class TensorTypeNode:

class TensorTypeNode : public BaseTensorTypeNode {
 public:

  Array shape;   // Tensor的shape
  DataType dtype;    // Tensor中数据类型
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("shape", &shape);
    v->Visit("dtype", &dtype);
    v->Visit("span", &span);
  }
  bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const {...}
  void SHashReduce(SHashReducer hash_reduce) const {...}
  TVM_DLL PrimExpr Size() const;
  static constexpr const char* _type_key = "relay.TensorType";
  TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
};

它定义了一个Tensor所需要的基本数据信息如:shape与数据类型,但是并没有实际的数据,所以类名也就叫TensorTypeNode。通过它可以获取到输入Tensor的类型信息从而对参数做合法性检查。

Q:函数输入参数reporter的含意是什么?
A:class TypeReporter是一个TypeReporterNode的容器类:

class TypeReporter : public ObjectRef {
 public:
  TypeReporter() {}
  explicit TypeReporter(ObjectPtr n) : ObjectRef(n) {}
  TypeReporterNode* operator->() const {
    return const_cast(static_cast(get()));
  }
  using ContainerType = TypeReporterNode;
};
 
 

它重载了运算符"->",所以:

reporter->Assign(types[1], TensorType(data->shape, data->dtype));

首先会实例化一个TensorType对象,因为我们的例子是对某一个维度的数据取绝对值,所以输出的数据shape及dtype与输入相同。然后通过reporter->Assign()调用class TypeReporterNode中纯虚函数virtual void Assign(dst, src) = 0,将创建好的TensorType对象赋值给输出TensorType,即types[1]。

3、关联算子的参数数目、属性

这一步的操作,为自定义算子注册算子名称,通过调用接口增加算子注释。这里需要用到C++的宏RELAY_REGISTER_OP,涉及的参数含义如下:

  • Arity(参数数量)

  • 位置参数的名称和描述

  • 支持级别(1 表示内部实现;较高的数字表示较少的内部支持或外部支持的算子)

  • 算子的类型关系

  • 优化算子时有用的其他注释。

需要在src/relay/op/tensor/transform.cc中注册算子并设置相关属性:

RELAY_REGISTER_OP("axis_abs")
    .describe(R"doc(Computes the axis abs of a tensor.)doc") TVM_ADD_FILELINE)
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor")
    .set_support_level(3)
    .add_type_rel("axis_abs", AxisAbsRel)
    .set_attr("TOpPattern", kOpaque);

Q:宏RELAY_REGISTER_OP做了什么?
A:RELAY_REGISTER_OP用于注册Relay算子:

#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y) 
#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op

#define TVM_REGISTER_OP(OpName)                          \
  TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \
      ::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()

展开为:

static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_Op0=::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()

其中COUNTER为编译器内置宏,初值是0,每预编译一次其值自己加1,通常配合 ## 使用,用于构建唯一的标识符,做法其实很简单,把任意一个标识符与 COUNTER 合并就可以了:

#define STR_CONCAT_(x, y) x##y  // 合并用的宏
#define STR_CONCAT(x, y) STR_CONCAT_(x, y)    // 因为 ## 的特性 ( 阻止另一个宏的展开 ),需要中间层
#define UNIQUE_NAME(name) STR_CONCAT(name, __COUNTER__)  // 把标识符与 __COUNTER__合并, 就可以建立唯一的变数名称了

而::tvm::OpRegEntry::RegisterOrGet(OpName)通过算子名称在全局的算子注册机对象中查找算子并返回OpRegEntry对象:

OpRegEntry& OpRegEntry::RegisterOrGet(const String& name) {
  return OpRegistry::Global()->RegisterOrGet(name);
}

Q:类OpRegEntry定义了什么?
A:类的定义在include/tvm/ir/op.h:

class OpRegEntry {public:
  const Op& op() const { return op_; }
  inline OpRegEntry& describe(const std::string& descr); 
  inline OpRegEntry& add_argument(const std::string& name, const std::string& type,
                                  const std::string& description);
  inline OpRegEntry& add_type_rel(const std::string& rel_name,
      runtime::TypedPackedFunc&, int, const Attrs&, const TypeReporter&)> type_rel_func);
  template 
  inline OpRegEntry& set_attrs_type();
  inline OpRegEntry& set_attrs_type_key(const String& key);
  inline OpRegEntry& set_num_inputs(int32_t n); 
  inline OpRegEntry& set_support_level(int32_t level); 
  template 
  inline OpRegEntry& set_attr(const std::string& attr_name, 
                              const ValueType& value, int plevel = 10);
  inline void reset_attr(const std::string& attr_name);
  inline OpRegEntry& set_name() {...}
  TVM_DLL static OpRegEntry& RegisterOrGet(const String& name);
private:
  template 
  friend class AttrRegistry;
  std::string name;
  Op op_;
  TVM_DLL OpRegEntry(uint32_t reg_index);
  inline OpNode* get()
  TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel);
};

可以看到,大部分成员函数都是返回自身的指针,从而方便链式调用,它们的实现代码与定义在同一个头文件中。其中的get()私有成员函数会返回OpNode指针,其它成员函数通过get()来获取算子节点的指针:

inline OpNode* OpRegEntry::get() { return const_cast(op_.operator->()); }

Q:链式调用中的几个函数作用是什么?
A:链式调用中的几个函数都是对OpNode节点对象的成员进行赋值,所以需要了解class OpNode的定义:

class OpNode : public RelayExprNode {
public:
  String name;  // 算子的名称
  mutable FuncType op_type;  // 算子的类型
  String description;  // 算子的具体描述,可以用在自动生成说明文档
  Array arguments;  // 算子的输入参数信息
  String attrs_type_key;  // 属性字段的类型键值,可以为空
  uint32_t attrs_type_index{0};  // 属性的类型索引
  int32_t num_inputs = -1;  // 算子输入参数个数,-1表示可变长
  int32_t support_level = 10; // 算子的支持等级,值越低优先级越高。void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("op_type", &op_type);
    v->Visit("description", &description);
    v->Visit("arguments", &arguments);
    v->Visit("attrs_type_key", &attrs_type_key);
    v->Visit("num_inputs", &num_inputs);
    v->Visit("support_level", &support_level);
  }
  ...
  static constexpr const char* _type_key = "Op";
  TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode);
private:
   ...
};

所以在链式调用的函数中,只做了简单赋值的函数是:

  • describe()就是给OpNode->description赋值;

  • set_num_inputs()是设置输入参数个数;

  • set_support_level()是设置支持等级;

  • add_argument()为arguments数组添加元素;

  • set_attr<>()会调用class AttrRegistry中的UpdateAttr()方法进行属性更新。

其中,对于add_argument(),因为TVM将每个算子的参数都用AttrFieldInfo描述,而AttrFieldInfo实际的内容是AttrFieldInfoNode:

class AttrFieldInfoNode : public Object {
 public:
  String name; // 字段名称
  String type_info;  // 类型说明
  String description; // 详细描述

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("type_info", &type_info);
    v->Visit("description", &description);
  }

  static constexpr const char* _type_key = "AttrFieldInfo";
  static constexpr bool _type_has_method_sequal_reduce = false;
  static constexpr bool _type_has_method_shash_reduce = false;
  TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};

所以add_argument()在赋值前,会创建一个AttrFieldInfoNode对象再把它放入到arguments数组中:

inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type,
                                            const std::string& description) {
  auto n = make_object();
  n->name = name;
  n->type_info = type;
  n->description = description;
  get()->arguments.push_back(AttrFieldInfo(n));
  return *this;
}

需要详细说明的是.add_type_rel() 和.set_attr()。

Q:add_type_rel()流程是怎么样的?
A:在函数中会创建输入与输出的TypeVarNode,然后创建TypeRelationNode将类型关系函数管理起来,并定义一个FuncTypeNode将这些定义好的对象作为输入,最终赋值给op_type。

inline OpRegEntry& OpRegEntry::add_type_rel(
    const std::string& rel_name,
    runtime::TypedPackedFunc&, int, const Attrs&, const TypeReporter&)>
        type_rel_func) {
  auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
  TypeRelationFn env_type_rel_func;

  if (runtime::Registry::Get(func_name)) {
    auto env_func = EnvFunc::Get(func_name);
    env_type_rel_func = env_func;
  } else {
    runtime::Registry::Register(func_name).set_body(type_rel_func.packed()); // 创建registy注册type_rel_func
    auto env_func = EnvFunc::Get(func_name);
    env_type_rel_func = env_func; // 这个跟第二小节定义的类型关系函数相关联
  }

  Array type_params;  // TypeVar是Type的子类
  Array arg_types;
  // Add inputs.
  std::string input_name_prefix = "in";
  for (int i = 0; i < get()->num_inputs; i++) {
    auto name = input_name_prefix + std::to_string(i);
    auto param = TypeVar(name, TypeKind::kType);  // 创建一个TypeVarNode对象
    type_params.push_back(param);
    arg_types.push_back(param);
  }
  Array ty_call_args = arg_types;

  // Add output type.
  auto out_param = TypeVar("out", TypeKind::kType);
  type_params.push_back(out_param);
  ty_call_args.push_back(out_param);
  TypeConstraint type_rel =
  TypeRelation(env_type_rel_func, ty_call_args, arg_types.size(), Attrs());// 创建TypeRelationNode
  auto func_type = FuncType(arg_types, out_param, type_params, {type_rel}); // 创建FuncTypeNode
  get()->op_type = func_type;  // 对op_type成员赋值  
  return *this;
}

TypeRelation()会创建一个TypeRelationNode,它实际上保存了之前定义的类型关系函数的相关信息:

class TypeRelationNode : public TypeConstraintNode {
public:
  TypeRelationFn func;  
  Array args; // The type arguments to the type function.
  int num_inputs; // Number of inputs arguments
  Attrs attrs; // Attributes to the relation function
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("func", &func);
    v->Visit("args", &args);
    v->Visit("num_inputs", &num_inputs);
    v->Visit("attrs", &attrs);
    v->Visit("span", &span);
  }
  ...
}

FuncType()创建FuncTypeNode,将定义的输入、输出、参数类型和类型关系节点作为输入:

class FuncTypeNode : public TypeNode {
 public:
  Array arg_types; // type type of arguments
  Type ret_type; // The type of return value
  Array type_params; // The type parameters of the function
  Array type_constraints; // potential constraint the type need to obey
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("arg_types", &arg_types);
    v->Visit("ret_type", &ret_type);
    v->Visit("type_params", &type_params);
    v->Visit("type_constraints", &type_constraints);
    v->Visit("span", &span);
  }
  ...
}
4、算子compute实现

有两种方式实现算子计算过程:

(1)python端实现计算

在python/tvm/topi/transform.py添加:

def axis_abs(x, axis, indice):
    """Take absolute value of the input of axis in x, element-wise.

    Parameters
    ----------
    x : tvm.te.Tensor
        Input argument.
    axis: int
        Input argument.
    indice: int
        Input argument.
    Returns
    -------
    y : tvm.te.Tensor
        The result.
    """
    ishape = x.shape
    assert len(ishape) == 3
    assert indice < get_const_int(ishape[axis])
    assert indice >= 0
    if axis == 0:
        return te.compute(x.shape, lambda i,j,k: te.if_then_else(x[i,j,k] >= 0, x[i,j,k],
                            te.if_then_else(i == indice, -x[i,j,k], x[i,j,k])))
    elif axis == 1:
        return te.compute(x.shape, lambda i, j, k: te.if_then_else(x[i, j, k] >= 0, x[i, j, k],
                            te.if_then_else(j == indice, -x[i, j, k], x[i, j, k])))
    else:
        return te.compute(x.shape, lambda i, j, k: te.if_then_else(x[i, j, k] >= 0, x[i, j, k],
                            te.if_then_else(k == indice, -x[i, j, k], x[i, j, k])))

并且在python/tvm/relay/op/_transform.py中设置算子计算函数属性:

@_reg.register_compute("axis_abs")  # 设置算子的计算函数属性,默认的level为10
def compute_axis_abs(attrs, inputs, output_type):
    """Compute definition of axis_abs"""
    return topi.axis_abs(inputs[0], attrs.axis, attrs.indice)

(2)C++端实现计算

在src/relay/op/tensor/transform.cc添加算子计算函数:

Array AxisAbsCompute(const Attrs& attrs, const Array& inputs,
                                    const Type& out_type) {
    // TODO
}

并且调用RELAY_REGISTER_OP("axis_abs")注册算子时需要设置它的计算函数属性:

.set_attr("FTVMCompute", AxisAbsCompute)

此时在python/tvm/topi/transform.py中的算子实现可以直接调用cpp的代码:

def axis_abs(x, axis, indice):
    """Take absolute value of the input of axis in x, element-wise.

    Parameters
    ----------
    x : tvm.te.Tensor
        Input argument.
    axis: int
        Input argument.
    indice: int
        Input argument.
    Returns
    -------
    y : tvm.te.Tensor
        The result.
    """
    return cpp.axis_abs(x, axis, indice)
5、注册算子的compute、schedule

在实现了算子compute逻辑以后,需要与我们实现的算子接口绑定在一起。在TVM中,这就需要不仅实现算子的compute接口,还要实现对应的schedule。而strategy就是对compute选择合适的schedule。需要在python/tvm/relay/op/strategy/generic.py添加算子的计算策略:

def wrap_compute_axis_abs(topi_compute):
    """Wrap axis_abs topi compute"""

    def _compute_axis_abs(attrs, inputs, _):
        return [topi_compute(inputs[0], attrs.axis, attrs.indice)]

    return _compute_axis_abs

@override_native_generic_func("axis_abs_strategy")
def axis_abs_strategy(attrs, inputs, out_type, target):
    """axis_abs generic strategy"""
    strategy = _op.OpStrategy()
    strategy.add_implementation(
        wrap_compute_axis_abs(topi.axis_abs),
        wrap_topi_schedule(topi.generic.schedule_injective),
        name="axix_abs.generic",
    )
    return strategy

在python/tvm/relay/op/_transform.py中将算子与计算策略关联起来:

_reg.register_strategy("axis_abs", strategy.axis_abs_strategy)
6、为算子生成调用节点并注册 API hook

现在有一个可以调用的relay算子,下一步就是如何通过relay call node调用。这就需要实现一个函数,传递相应的参数给对应的relay算子,并且返回对应算子的Call Node(这个算子最终在Relay表达式的AST里面)。需要在src\relay\op\tensor\transform.cc添加:

Expr MakeAxisAbs(Expr data, int axis, int indice) {
    auto attrs = make_object();
    attrs->axis = axis;
    attrs->indice = indice;
    static const Op& op = Op::Get("axis_abs");
    return Call(op, {data}, Attrs(attrs), {}); // 会创建一个CallNode实例
}

TVM_REGISTER_GLOBAL("relay.op._make.axis_abs").set_body_typed(MakeAxisAbs);

Q:Call Node是什么?
A:CallNode类是ExprNode的子类,它在程序调用Call函数时被实例化:

class CallNode : public ExprNode {
 protected:
  Object::FDeleter saved_deleter_;
  static void Deleter_(Object* ptr);
 public:
  Expr op; // 算子的计算表达函数
  tvm::Array args;  // call函数的输入参数
  Attrs attrs; // 属性
  tvm::Array type_args;  // 传递给多态(模板)函数的类型参数
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("op", &op);
    v->Visit("args", &args);
    v->Visit("attrs", &attrs);
    v->Visit("type_args", &type_args);
    v->Visit("span", &span);
    v->Visit("_checked_type_", &checked_type_);
  }
  bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {...}
  void SHashReduce(SHashReducer hash_reduce) const {...}
  static constexpr const char* _type_key = "relay.Call";
  TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
  template 
  friend class runtime::ObjAllocatorBase;
  friend class Call;
};

Call::Call(Expr op, Array args, Attrs attrs, Array type_args, Span span) {
  ObjectPtr n = make_object();  // 创建CallNode
  n->op = std::move(op);
  n->args = std::move(args);
  n->attrs = std::move(attrs);
  n->type_args = std::move(type_args);
  n->span = std::move(span); 
  data_ = std::move(n);
}
7、将Python API hook 封装成简洁的调用方式

为更方便的使用,通常的做法是构造单独的函数,因此最好封装成更简洁的python接口,需要在python/tvm/relay/op/transform.py中添加:

def axis_abs(data, axis=0, indice=0):
    """Computes abs of data along a certain axis indice.

    Parameters
    ----------
    data : relay.Expr
        The source data to be invert permuated.

    Returns
    -------
    ret : relay.Expr
        Invert permuated data. Has the same type as data.
    """
    return _make.axis_abs(data, axis, indice)
8、为新的relay 算子编写测试用例

需要在tvm/tests/python/test_op_level3.py添加:

class TestAxisAbs:
    dshape, axis, indice = tvm.testing.parameters(     # 定义测试用例参数,这里是输入tensor的shape,axis和indice
        ((4, 4, 1), 1, 1),
        ((4, 4, 1), 0, 1),
        ((3, 3, 3), 1, 1),
    )

    def test_axis_abs(self, dshape, axis, indice):
        x = relay.var("x", relay.TensorType(dshape, "int32"))  # 定义relay输入tensor
        y = relay.axis_abs(x, axis=axis, indice=indice)    # 定义axis_abs运算表达式
        yy = run_infer_type(y)      # 推理运算表达式的类型,定义在python/tvm/relay/testing/__init__.py
        assert yy.checked_type == relay.TensorType(dshape, "int32")  # 类型测试

        data = np.random.randint(-5, 5, size=dshape).astype("int32")
        op_res = create_executor().evaluate(y, {x: relay.const(data)})  # 创建执行器并执行算子推理
        if axis == 0:
            data[indice,:,:] = np.abs(data[indice,:,:])
        elif axis == 1:
            data[:,indice, :] = np.abs(data[:,indice,:])
        else:
            data[:,:,indice] = np.abs(data[:,:,indice])
        ref_res = data
        np.testing.assert_equal(op_res.numpy(), ref_res)  # 对比numpy结果与relay的计算结果

如果没有安装pytest,要先安装pytest,再运行测试用例:

pip install pytest && cd tvm/tests/python && pytest relay/test_op_level3.py::TestAxisAbs

用例通过的结果如下:

image.png

三、总结

本文根据TVM官方文档给出的步骤添加了一个自定义算子,并对过程中有疑问的地方做了一些说明。

你可能感兴趣的:(【TVM系列五】添加Relay自定义算子)