一、前言
本文以实现一个axis_abs的自定义算子为例介绍如何在tvm中添加新的relay算子,该算子实现的功能是以输入的3维tensor取某一维度的指定切片取绝对值。
二、添加自定义算子
新增relay算子基本是下面几个步骤:
定义新增算子的属性节点(Attribute Node),声明在编译时已知的固定参数;
为新增算子编写类型关系,以集成到relay的类型系统中;
使用C++RELAY_REGISTER_OP宏,为新增算子注册生命参数数量、类型、提示信息;
算子的compute实现;
注册算子的compute、schedule;
定义C++函数,为新增算子生成调用节点,并为该函数注册 Python API hook;
将上面的 Python API hook 封装成简洁的调用方式;
为新的relay 算子编写测试。
1、定义新增算子的属性节点(Attribute Node)
在include/tvm/relay/attrs/transform.h中增加算子的属性数据结构:
/*! \brief Attributes used in axisabs operator */
struct AxisAbsAttrs : public tvm::AttrsNode {
int axis;
int indice;
TVM_DECLARE_ATTRS(AxisAbsAttrs, "relay.attrs.AxisAbsAttrs") {
TVM_ATTR_FIELD(axis).set_default(0).describe("Axis to abs");
TVM_ATTR_FIELD(indice).set_default(0).describe("Indice to abs");
}
};
Q:宏TVM_DECLARE_ATTRS 与 TVM_ATTR_FIELD的作用是什么?
A:这两个宏定义在 include/tvm/ir/attrs.h
#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
static constexpr const char* _type_key = TypeKey; \
TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
template \
void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*)
#define TVM_ATTR_FIELD(FieldName) __fvisit__(#FieldName, &FieldName)
其中的TVM_DECLARE_FINAL_OBJECT_INFO定义在include/tvm/runtime/object.h
#define TVM_DECLARE_FINAL_OBJECT_INFO(TypeName, ParentType) \
static const constexpr bool _type_final = true; \
static const constexpr int _type_child_slots = 0; \
TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType)
#define TVM_DECLARE_BASE_OBJECT_INFO(TypeName, ParentType) \
static_assert(!ParentType::_type_final, "ParentObj marked as final"); \
static uint32_t RuntimeTypeIndex() { \
static_assert(TypeName::_type_child_slots == 0 || ParentType::_type_child_slots == 0 || \
TypeName::_type_child_slots < ParentType::_type_child_slots, \
"Need to set _type_child_slots when parent specifies it."); \
if (TypeName::_type_index != ::tvm::runtime::TypeIndex::kDynamic) { \
return TypeName::_type_index; \
} \
return _GetOrAllocRuntimeTypeIndex(); \
} \
static uint32_t _GetOrAllocRuntimeTypeIndex() { \
static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex( \
TypeName::_type_key, TypeName::_type_index, ParentType::_GetOrAllocRuntimeTypeIndex(), \
TypeName::_type_child_slots, TypeName::_type_child_slots_can_overflow); \
return tindex; \
}
所以宏展开后定义的属性节点数据结构为:
struct AxisAbsAttrs : public tvm::ArrayNode {
int axis;
static constexpr const char* _type_key = "relay.attrs.AxisAbsAttrs";
static const constexpr bool _type_final = true;
static const constexpr int _type_child_slots = 0;
static_assert(!::tvm::BaseAttrsNode::_type_final, "ParentObj marked as final");
static uint32_t RuntimeTypeIndex() {
static_assert(AxisAbsAttrs::_type_child_slots == 0 || ::tvm::BaseAttrsNode::_type_child_slots == 0 ||
AxisAbsAttrs::_type_child_slots < ::tvm::BaseAttrsNode::_type_child_slots,
"Need to set _type_child_slots when parent specifies it.");
if (AxisAbsAttrs::_type_index != ::tvm::runtime::TypeIndex::kDynamic) {
return AxisAbsAttrs::_type_index;
}
return _GetOrAllocRuntimeTypeIndex();
}
static uint32_t _GetOrAllocRuntimeTypeIndex() {
static uint32_t tindex = Object::GetOrAllocRuntimeTypeIndex(
AxisAbsAttrs::_type_key, AxisAbsAttrs::_type_index, ::tvm::BaseAttrsNode::_GetOrAllocRuntimeTypeIndex(),
AxisAbsAttrs::_type_child_slots, AxisAbsAttrs::_type_child_slots_can_overflow);
return tindex;
}
template
void __VisitAttrs__(FVisit& __fvisit__) {
__fvisit__(axis, &axis).set_default(0).describe("Axis to abs");
}
}
可以看到,每个属性节点都定义了获取运行时类型索引的函数RuntimeTypeIndex()以及访问属性内部成员的模版函数VisitAttrs(FVisit& fvisit)。
Q:模版函数VisitAttrs(FVisit& fvisit)的调用过程是怎么样的?
A:首先分析定义在include/tvm/ir/attrs.h中的类class AttrsNode
template
class AttrsNode : public BaseAttrsNode {
public:
void VisitAttrs(AttrVisitor* v) {
::tvm::detail::AttrNormalVisitor vis(v);
self()->__VisitAttrs__(vis);
}
void VisitNonDefaultAttrs(AttrVisitor* v) {...}
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final {...}
bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {...}
void SHashReduce(SHashReducer hash_reducer) const {...}
Array ListFieldInfo() const final {...}
private:
DerivedType* self() const {
return const_cast(static_cast(this));
}
};
它是一个模版类,模版参数是继承它的子类类型,在成员函数VisitAttrs(AttrVisitor* v)中,传入属性访问器类AttrVisitor对象:
class AttrVisitor {
public:
//! \cond Doxygen_Suppress
TVM_DLL virtual ~AttrVisitor() = default;
TVM_DLL virtual void Visit(const char* key, double* value) = 0;
TVM_DLL virtual void Visit(const char* key, int64_t* value) = 0;
TVM_DLL virtual void Visit(const char* key, uint64_t* value) = 0;
TVM_DLL virtual void Visit(const char* key, int* value) = 0;
TVM_DLL virtual void Visit(const char* key, bool* value) = 0;
TVM_DLL virtual void Visit(const char* key, std::string* value) = 0;
TVM_DLL virtual void Visit(const char* key, void** value) = 0;
TVM_DLL virtual void Visit(const char* key, DataType* value) = 0;
TVM_DLL virtual void Visit(const char* key, runtime::NDArray* value) = 0;
TVM_DLL virtual void Visit(const char* key, runtime::ObjectRef* value) = 0;
template ::value>::type>
void Visit(const char* key, ENum* ptr) {
static_assert(std::is_same::type>::value,
"declare enum to be enum int to use visitor");
this->Visit(key, reinterpret_cast(ptr));
}
//! \endcond
};
然后通过::tvm::detail::AttrNormalVisitor vis(v);包裹一层普通属性访问函数:
// Wrapper for normal visitor.
class AttrNormalVisitor {
public:
explicit AttrNormalVisitor(AttrVisitor* visitor) : visitor_(visitor) {}
template
AttrNopEntry operator()(const char* key, T* value) {
visitor_->Visit(key, value);
return AttrNopEntry();
}
private:
AttrVisitor* visitor_;
};
它重载了运算符“()”,当class AttrsNode通过self()->VisitAttrs(vis)获取子类的对象并通过子类对象调用VisitAttrs(FVisit& fvisit) 时,随即调用了fvisit(axis, &axis),这个fvisit最终调到的就是class AttrNormalVisitor 中的重载"()"函数,这个函数会返回一个结构体用于支持链式调用:
// helper entry that does nothing in set_default/bound/describe calls.
struct AttrNopEntry {
using TSelf = AttrNopEntry;
TSelf& describe(DMLC_ATTRIBUTE_UNUSED const char* str) { return *this; }
template
TSelf& set_default(DMLC_ATTRIBUTE_UNUSED const T& value) {return *this;}
template
TSelf& set_lower_bound(DMLC_ATTRIBUTE_UNUSED const T& begin) {return *this;}
template
TSelf& set_upper_bound(DMLC_ATTRIBUTE_UNUSED const T& end) {return *this;}
};
这些调用实际上什么都没有做就返回了其自身。
2、编写算子类型关系,集成到Relay的类型系统
为了算子注册的灵活性以及relay算子有更好的泛化能力,relay算子通过输入输出之间的类型关系来实例化。本质上,算子类型关系除了推导输出类型外,还能够强制指定类型规则(检查输入类型)。需要在src\relay\op\tensor\transform.cc中添加算子的类型关系处理函数:
bool AxisAbsRel(const Array& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, output]
ICHECK_EQ(types.size(), 2);
const auto* data = types[0].as();
if (data == nullptr) {
ICHECK(types[0].as())
<< "cast: expect input type to be TensorType but get " << types[0];
return false;
}
const auto* param = attrs.as();
const int ndim = static_cast(data->shape.size());
const int axis = param->axis;
const int axis_len = data->shape[axis].as()->value;
const int indice = param->indice;
ICHECK(0 <= axis && axis < ndim)
<< "axis_abs only accepts `axis` in [0, data.ndim - 1]"
<< ", but got axis = " << axis << ", and data.ndim = " << ndim;
ICHECK(0 <= indice && indice < axis_len)
<< "axis_abs only accepts `indice` in [0, data[axis] - 1"
<< ", but got indice = " << indice << ", and data[axis] = " << axis_len;
reporter->Assign(types[1], TensorType(data->shape, data->dtype));
return true;
}
Q:类型关系处理函数在什么时候调用?
A:类型关系处理函数在注册Relay算子时通过链式调用add_type_rel()注册。
Q:函数输入参数types的含意是什么?
A:types传入的是一个数组引用,内容一般为输入与输出的TensorType,首先看class TensorTypeNode:
class TensorTypeNode : public BaseTensorTypeNode {
public:
Array shape; // Tensor的shape
DataType dtype; // Tensor中数据类型
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
v->Visit("span", &span);
}
bool SEqualReduce(const TensorTypeNode* other, SEqualReducer equal) const {...}
void SHashReduce(SHashReducer hash_reduce) const {...}
TVM_DLL PrimExpr Size() const;
static constexpr const char* _type_key = "relay.TensorType";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, BaseTensorTypeNode);
};
它定义了一个Tensor所需要的基本数据信息如:shape与数据类型,但是并没有实际的数据,所以类名也就叫TensorTypeNode。通过它可以获取到输入Tensor的类型信息从而对参数做合法性检查。
Q:函数输入参数reporter的含意是什么?
A:class TypeReporter是一个TypeReporterNode的容器类:
class TypeReporter : public ObjectRef {
public:
TypeReporter() {}
explicit TypeReporter(ObjectPtr
它重载了运算符"->",所以:
reporter->Assign(types[1], TensorType(data->shape, data->dtype));
首先会实例化一个TensorType对象,因为我们的例子是对某一个维度的数据取绝对值,所以输出的数据shape及dtype与输入相同。然后通过reporter->Assign()调用class TypeReporterNode中纯虚函数virtual void Assign(dst, src) = 0,将创建好的TensorType对象赋值给输出TensorType,即types[1]。
3、关联算子的参数数目、属性
这一步的操作,为自定义算子注册算子名称,通过调用接口增加算子注释。这里需要用到C++的宏RELAY_REGISTER_OP,涉及的参数含义如下:
Arity(参数数量)
位置参数的名称和描述
支持级别(1 表示内部实现;较高的数字表示较少的内部支持或外部支持的算子)
算子的类型关系
优化算子时有用的其他注释。
需要在src/relay/op/tensor/transform.cc中注册算子并设置相关属性:
RELAY_REGISTER_OP("axis_abs")
.describe(R"doc(Computes the axis abs of a tensor.)doc") TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor")
.set_support_level(3)
.add_type_rel("axis_abs", AxisAbsRel)
.set_attr("TOpPattern", kOpaque);
Q:宏RELAY_REGISTER_OP做了什么?
A:RELAY_REGISTER_OP用于注册Relay算子:
#define TVM_STR_CONCAT_(__x, __y) __x##__y
#define TVM_STR_CONCAT(__x, __y) TVM_STR_CONCAT_(__x, __y)
#define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op
#define TVM_REGISTER_OP(OpName) \
TVM_STR_CONCAT(TVM_OP_REGISTER_VAR_DEF, __COUNTER__) = \
::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()
展开为:
static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_Op0=::tvm::OpRegEntry::RegisterOrGet(OpName).set_name()
其中COUNTER为编译器内置宏,初值是0,每预编译一次其值自己加1,通常配合 ## 使用,用于构建唯一的标识符,做法其实很简单,把任意一个标识符与 COUNTER 合并就可以了:
#define STR_CONCAT_(x, y) x##y // 合并用的宏
#define STR_CONCAT(x, y) STR_CONCAT_(x, y) // 因为 ## 的特性 ( 阻止另一个宏的展开 ),需要中间层
#define UNIQUE_NAME(name) STR_CONCAT(name, __COUNTER__) // 把标识符与 __COUNTER__合并, 就可以建立唯一的变数名称了
而::tvm::OpRegEntry::RegisterOrGet(OpName)通过算子名称在全局的算子注册机对象中查找算子并返回OpRegEntry对象:
OpRegEntry& OpRegEntry::RegisterOrGet(const String& name) {
return OpRegistry::Global()->RegisterOrGet(name);
}
Q:类OpRegEntry定义了什么?
A:类的定义在include/tvm/ir/op.h:
class OpRegEntry {public:
const Op& op() const { return op_; }
inline OpRegEntry& describe(const std::string& descr);
inline OpRegEntry& add_argument(const std::string& name, const std::string& type,
const std::string& description);
inline OpRegEntry& add_type_rel(const std::string& rel_name,
runtime::TypedPackedFunc&, int, const Attrs&, const TypeReporter&)> type_rel_func);
template
inline OpRegEntry& set_attrs_type();
inline OpRegEntry& set_attrs_type_key(const String& key);
inline OpRegEntry& set_num_inputs(int32_t n);
inline OpRegEntry& set_support_level(int32_t level);
template
inline OpRegEntry& set_attr(const std::string& attr_name,
const ValueType& value, int plevel = 10);
inline void reset_attr(const std::string& attr_name);
inline OpRegEntry& set_name() {...}
TVM_DLL static OpRegEntry& RegisterOrGet(const String& name);
private:
template
friend class AttrRegistry;
std::string name;
Op op_;
TVM_DLL OpRegEntry(uint32_t reg_index);
inline OpNode* get()
TVM_DLL void UpdateAttr(const String& key, runtime::TVMRetValue value, int plevel);
};
可以看到,大部分成员函数都是返回自身的指针,从而方便链式调用,它们的实现代码与定义在同一个头文件中。其中的get()私有成员函数会返回OpNode指针,其它成员函数通过get()来获取算子节点的指针:
inline OpNode* OpRegEntry::get() { return const_cast(op_.operator->()); }
Q:链式调用中的几个函数作用是什么?
A:链式调用中的几个函数都是对OpNode节点对象的成员进行赋值,所以需要了解class OpNode的定义:
class OpNode : public RelayExprNode {
public:
String name; // 算子的名称
mutable FuncType op_type; // 算子的类型
String description; // 算子的具体描述,可以用在自动生成说明文档
Array arguments; // 算子的输入参数信息
String attrs_type_key; // 属性字段的类型键值,可以为空
uint32_t attrs_type_index{0}; // 属性的类型索引
int32_t num_inputs = -1; // 算子输入参数个数,-1表示可变长
int32_t support_level = 10; // 算子的支持等级,值越低优先级越高。void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("op_type", &op_type);
v->Visit("description", &description);
v->Visit("arguments", &arguments);
v->Visit("attrs_type_key", &attrs_type_key);
v->Visit("num_inputs", &num_inputs);
v->Visit("support_level", &support_level);
}
...
static constexpr const char* _type_key = "Op";
TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode);
private:
...
};
所以在链式调用的函数中,只做了简单赋值的函数是:
describe()就是给OpNode->description赋值;
set_num_inputs()是设置输入参数个数;
set_support_level()是设置支持等级;
add_argument()为arguments数组添加元素;
set_attr<>()会调用class AttrRegistry中的UpdateAttr()方法进行属性更新。
其中,对于add_argument(),因为TVM将每个算子的参数都用AttrFieldInfo描述,而AttrFieldInfo实际的内容是AttrFieldInfoNode:
class AttrFieldInfoNode : public Object {
public:
String name; // 字段名称
String type_info; // 类型说明
String description; // 详细描述
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("type_info", &type_info);
v->Visit("description", &description);
}
static constexpr const char* _type_key = "AttrFieldInfo";
static constexpr bool _type_has_method_sequal_reduce = false;
static constexpr bool _type_has_method_shash_reduce = false;
TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object);
};
所以add_argument()在赋值前,会创建一个AttrFieldInfoNode对象再把它放入到arguments数组中:
inline OpRegEntry& OpRegEntry::add_argument(const std::string& name, const std::string& type,
const std::string& description) {
auto n = make_object();
n->name = name;
n->type_info = type;
n->description = description;
get()->arguments.push_back(AttrFieldInfo(n));
return *this;
}
需要详细说明的是.add_type_rel() 和.set_attr()。
Q:add_type_rel()流程是怎么样的?
A:在函数中会创建输入与输出的TypeVarNode,然后创建TypeRelationNode将类型关系函数管理起来,并定义一个FuncTypeNode将这些定义好的对象作为输入,最终赋值给op_type。
inline OpRegEntry& OpRegEntry::add_type_rel(
const std::string& rel_name,
runtime::TypedPackedFunc&, int, const Attrs&, const TypeReporter&)>
type_rel_func) {
auto func_name = std::string("tvm.relay.type_relation.") + rel_name;
TypeRelationFn env_type_rel_func;
if (runtime::Registry::Get(func_name)) {
auto env_func = EnvFunc::Get(func_name);
env_type_rel_func = env_func;
} else {
runtime::Registry::Register(func_name).set_body(type_rel_func.packed()); // 创建registy注册type_rel_func
auto env_func = EnvFunc::Get(func_name);
env_type_rel_func = env_func; // 这个跟第二小节定义的类型关系函数相关联
}
Array type_params; // TypeVar是Type的子类
Array arg_types;
// Add inputs.
std::string input_name_prefix = "in";
for (int i = 0; i < get()->num_inputs; i++) {
auto name = input_name_prefix + std::to_string(i);
auto param = TypeVar(name, TypeKind::kType); // 创建一个TypeVarNode对象
type_params.push_back(param);
arg_types.push_back(param);
}
Array ty_call_args = arg_types;
// Add output type.
auto out_param = TypeVar("out", TypeKind::kType);
type_params.push_back(out_param);
ty_call_args.push_back(out_param);
TypeConstraint type_rel =
TypeRelation(env_type_rel_func, ty_call_args, arg_types.size(), Attrs());// 创建TypeRelationNode
auto func_type = FuncType(arg_types, out_param, type_params, {type_rel}); // 创建FuncTypeNode
get()->op_type = func_type; // 对op_type成员赋值
return *this;
}
TypeRelation()会创建一个TypeRelationNode,它实际上保存了之前定义的类型关系函数的相关信息:
class TypeRelationNode : public TypeConstraintNode {
public:
TypeRelationFn func;
Array args; // The type arguments to the type function.
int num_inputs; // Number of inputs arguments
Attrs attrs; // Attributes to the relation function
void VisitAttrs(AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("args", &args);
v->Visit("num_inputs", &num_inputs);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
}
...
}
FuncType()创建FuncTypeNode,将定义的输入、输出、参数类型和类型关系节点作为输入:
class FuncTypeNode : public TypeNode {
public:
Array arg_types; // type type of arguments
Type ret_type; // The type of return value
Array type_params; // The type parameters of the function
Array type_constraints; // potential constraint the type need to obey
void VisitAttrs(AttrVisitor* v) {
v->Visit("arg_types", &arg_types);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("type_constraints", &type_constraints);
v->Visit("span", &span);
}
...
}
4、算子compute实现
有两种方式实现算子计算过程:
(1)python端实现计算
在python/tvm/topi/transform.py添加:
def axis_abs(x, axis, indice):
"""Take absolute value of the input of axis in x, element-wise.
Parameters
----------
x : tvm.te.Tensor
Input argument.
axis: int
Input argument.
indice: int
Input argument.
Returns
-------
y : tvm.te.Tensor
The result.
"""
ishape = x.shape
assert len(ishape) == 3
assert indice < get_const_int(ishape[axis])
assert indice >= 0
if axis == 0:
return te.compute(x.shape, lambda i,j,k: te.if_then_else(x[i,j,k] >= 0, x[i,j,k],
te.if_then_else(i == indice, -x[i,j,k], x[i,j,k])))
elif axis == 1:
return te.compute(x.shape, lambda i, j, k: te.if_then_else(x[i, j, k] >= 0, x[i, j, k],
te.if_then_else(j == indice, -x[i, j, k], x[i, j, k])))
else:
return te.compute(x.shape, lambda i, j, k: te.if_then_else(x[i, j, k] >= 0, x[i, j, k],
te.if_then_else(k == indice, -x[i, j, k], x[i, j, k])))
并且在python/tvm/relay/op/_transform.py中设置算子计算函数属性:
@_reg.register_compute("axis_abs") # 设置算子的计算函数属性,默认的level为10
def compute_axis_abs(attrs, inputs, output_type):
"""Compute definition of axis_abs"""
return topi.axis_abs(inputs[0], attrs.axis, attrs.indice)
(2)C++端实现计算
在src/relay/op/tensor/transform.cc添加算子计算函数:
Array AxisAbsCompute(const Attrs& attrs, const Array& inputs,
const Type& out_type) {
// TODO
}
并且调用RELAY_REGISTER_OP("axis_abs")注册算子时需要设置它的计算函数属性:
.set_attr("FTVMCompute", AxisAbsCompute)
此时在python/tvm/topi/transform.py中的算子实现可以直接调用cpp的代码:
def axis_abs(x, axis, indice):
"""Take absolute value of the input of axis in x, element-wise.
Parameters
----------
x : tvm.te.Tensor
Input argument.
axis: int
Input argument.
indice: int
Input argument.
Returns
-------
y : tvm.te.Tensor
The result.
"""
return cpp.axis_abs(x, axis, indice)
5、注册算子的compute、schedule
在实现了算子compute逻辑以后,需要与我们实现的算子接口绑定在一起。在TVM中,这就需要不仅实现算子的compute接口,还要实现对应的schedule。而strategy就是对compute选择合适的schedule。需要在python/tvm/relay/op/strategy/generic.py添加算子的计算策略:
def wrap_compute_axis_abs(topi_compute):
"""Wrap axis_abs topi compute"""
def _compute_axis_abs(attrs, inputs, _):
return [topi_compute(inputs[0], attrs.axis, attrs.indice)]
return _compute_axis_abs
@override_native_generic_func("axis_abs_strategy")
def axis_abs_strategy(attrs, inputs, out_type, target):
"""axis_abs generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_axis_abs(topi.axis_abs),
wrap_topi_schedule(topi.generic.schedule_injective),
name="axix_abs.generic",
)
return strategy
在python/tvm/relay/op/_transform.py中将算子与计算策略关联起来:
_reg.register_strategy("axis_abs", strategy.axis_abs_strategy)
6、为算子生成调用节点并注册 API hook
现在有一个可以调用的relay算子,下一步就是如何通过relay call node调用。这就需要实现一个函数,传递相应的参数给对应的relay算子,并且返回对应算子的Call Node(这个算子最终在Relay表达式的AST里面)。需要在src\relay\op\tensor\transform.cc添加:
Expr MakeAxisAbs(Expr data, int axis, int indice) {
auto attrs = make_object();
attrs->axis = axis;
attrs->indice = indice;
static const Op& op = Op::Get("axis_abs");
return Call(op, {data}, Attrs(attrs), {}); // 会创建一个CallNode实例
}
TVM_REGISTER_GLOBAL("relay.op._make.axis_abs").set_body_typed(MakeAxisAbs);
Q:Call Node是什么?
A:CallNode类是ExprNode的子类,它在程序调用Call函数时被实例化:
class CallNode : public ExprNode {
protected:
Object::FDeleter saved_deleter_;
static void Deleter_(Object* ptr);
public:
Expr op; // 算子的计算表达函数
tvm::Array args; // call函数的输入参数
Attrs attrs; // 属性
tvm::Array type_args; // 传递给多态(模板)函数的类型参数
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("op", &op);
v->Visit("args", &args);
v->Visit("attrs", &attrs);
v->Visit("type_args", &type_args);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const CallNode* other, SEqualReducer equal) const {...}
void SHashReduce(SHashReducer hash_reduce) const {...}
static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
template
friend class runtime::ObjAllocatorBase;
friend class Call;
};
Call::Call(Expr op, Array args, Attrs attrs, Array type_args, Span span) {
ObjectPtr n = make_object(); // 创建CallNode
n->op = std::move(op);
n->args = std::move(args);
n->attrs = std::move(attrs);
n->type_args = std::move(type_args);
n->span = std::move(span);
data_ = std::move(n);
}
7、将Python API hook 封装成简洁的调用方式
为更方便的使用,通常的做法是构造单独的函数,因此最好封装成更简洁的python接口,需要在python/tvm/relay/op/transform.py中添加:
def axis_abs(data, axis=0, indice=0):
"""Computes abs of data along a certain axis indice.
Parameters
----------
data : relay.Expr
The source data to be invert permuated.
Returns
-------
ret : relay.Expr
Invert permuated data. Has the same type as data.
"""
return _make.axis_abs(data, axis, indice)
8、为新的relay 算子编写测试用例
需要在tvm/tests/python/test_op_level3.py添加:
class TestAxisAbs:
dshape, axis, indice = tvm.testing.parameters( # 定义测试用例参数,这里是输入tensor的shape,axis和indice
((4, 4, 1), 1, 1),
((4, 4, 1), 0, 1),
((3, 3, 3), 1, 1),
)
def test_axis_abs(self, dshape, axis, indice):
x = relay.var("x", relay.TensorType(dshape, "int32")) # 定义relay输入tensor
y = relay.axis_abs(x, axis=axis, indice=indice) # 定义axis_abs运算表达式
yy = run_infer_type(y) # 推理运算表达式的类型,定义在python/tvm/relay/testing/__init__.py
assert yy.checked_type == relay.TensorType(dshape, "int32") # 类型测试
data = np.random.randint(-5, 5, size=dshape).astype("int32")
op_res = create_executor().evaluate(y, {x: relay.const(data)}) # 创建执行器并执行算子推理
if axis == 0:
data[indice,:,:] = np.abs(data[indice,:,:])
elif axis == 1:
data[:,indice, :] = np.abs(data[:,indice,:])
else:
data[:,:,indice] = np.abs(data[:,:,indice])
ref_res = data
np.testing.assert_equal(op_res.numpy(), ref_res) # 对比numpy结果与relay的计算结果
如果没有安装pytest,要先安装pytest,再运行测试用例:
pip install pytest && cd tvm/tests/python && pytest relay/test_op_level3.py::TestAxisAbs
用例通过的结果如下:
三、总结
本文根据TVM官方文档给出的步骤添加了一个自定义算子,并对过程中有疑问的地方做了一些说明。