TVM源码中涉及到表达式遍历的地方,一般是使用VisitExpr接口进行。这个接口涉及TVM的visitor模式,具体分析可以参考TVM之设计模式解读(一)--visitor模式
使用visitor遍历的起点是调用VisitExpr接口。我们看下基类tvm::relay::ExprFunctor中这个方法的代码:
template
class ExprFunctor {
private:
...
using FType = tvm::NodeFunctor;
public:
/*!
* \brief The functor call.
* \param n The expression node.
* \param args Additional arguments.
* \return The result of the call
*/
virtual R VisitExpr(const Expr& n, Args... args) {
ICHECK(n.defined()) << "Found null pointer node while traversing AST. The previous pass may "
"have generated invalid data.";
static FType vtable = InitVTable();
return vtable(n, this, std::forward(args)...);
}
...
}
VisitExpr中调用InitVTable,这个代码展开后:
template
class ExprFunctor {
private:
...
public:
...
private:
static FType InitVTable() {
FType vtable;
vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) {
return self->VisitExpr_(static_cast(n.get()), std::forward(args)...); });;
vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) {
return self->VisitExpr_(static_cast(n.get()), std::forward(args)...); });;
vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args) {
return self->VisitExpr_(static_cast(n.get()), std::forward(args)...); });;
vtable.template set_dispatch([](const ObjectRef& n, TSelf* self, Args... args)
{ return self->VisitExpr_(static_cast(n.get()), std::forward(args)...); });;
...
return vtable;
}
};
...
}
template
class NodeFunctor {
private:
...
/*! \brief internal function table */
std::vector func_;
...
public:
...
/*!
* \brief set the dispacher for type TNode
* \param f The function to be set.
* \tparam TNode the type of Node to be dispatched.
* \return reference to self.
*/
template
TSelf& set_dispatch(FPointer f) { // NOLINT(*)
uint32_t tindex = TNode::RuntimeTypeIndex();
if (func_.size() <= tindex) {
func_.resize(tindex + 1, nullptr);
}
ICHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set";
func_[tindex] = f;
return *this;
}
InitVTable中调用NodeFunctor::set_dispatch接口,类型参数为tvm relay ir的各种表达式类型,传入set_dispatch的函数参数是lamad函数,lamad函数体中执行self->VisitExpr_()。self时传入的参数this,当从派生类中发起VisitExpr的时候,这个this将是派生类实例,而不是基类。
NodeFunctor::set_dispatch是在函数指针表func_中添加传入的lamad函数,表项索引为类型参数的id。tvm中类型id设计分析可以参考
深入理解TVM:Object家族
InitVTable在为所有类型都调用set_dispatch注册对应的visit调用后,返回了注册的NodeFunctor实例。而VisitExpr在调用InitVTable后return vtable(n, this, std::forward
R operator()(const ObjectRef& n, Args... args) const {
ICHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type "
<< n->GetTypeKey();
return (*func_[n->type_index()])(n, std::forward(args)...);
}
这里以传入的参数的类型id为索引,从func_表中获取对应的lamad函数体,并调用执行。也就是执行了类实例的VisitExpr_。因为一般来说发起VisitExpr调用的是以tvm::relay::ExprFunctor为基类,并在VisitExpr_中完成业务操作的类,所以这里VisitExpr_是调用的业务类中重载后的VisitExpr_方法。业务类对自己关注的类型的VisitExpr_进行重载,在其中完成自己的操作。
如果派生类不对各种类型重载VisitExpr_,就会调用到tvm::relay::ExprFunctor定义的VisitExpr_,抛出异常:
virtual R VisitExpr_(const ConstantNode* op, Args... args) {
return VisitExprDefault_(op, std::forward(args)...);
};
virtual R VisitExpr_(const TupleNode* op, Args... args) {
return VisitExprDefault_(op, std::forward(args)...);
};
virtual R VisitExpr_(const VarNode* op, Args... args) {
return VisitExprDefault_(op, std::forward(args)...);
};
...
virtual R VisitExprDefault_(const Object* op, Args...) {
::tvm::runtime::detail::LogFatal("/home/tvm/tvmsource/tvm/include/tvm/relay/expr_functor.h", 114).stream() << "Do not have a default for " << op->GetTypeKey();
throw;
}
ExprVisitor继承了ExprFunctor,并对VisitExpr和VisitExpr_进行了重载:
class ExprVisitor : public ::tvm::relay::ExprFunctor {
public:
void VisitExpr(const Expr& expr) override;
void VisitExpr_(const VarNode* op) override;
void VisitExpr_(const GlobalVarNode* op) override;
void VisitExpr_(const ConstantNode* op) override;
void VisitExpr_(const TupleNode* op) override;
void VisitExpr_(const FunctionNode* op) override;
void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const LetNode* op) override;
void VisitExpr_(const IfNode* op) override;
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
void VisitExpr_(const RefCreateNode* op) override;
void VisitExpr_(const RefReadNode* op) override;
void VisitExpr_(const RefWriteNode* op) override;
void VisitExpr_(const ConstructorNode* op) override;
void VisitExpr_(const MatchNode* op) override;
virtual void VisitType(const Type& t);
virtual void VisitClause(const Clause& c);
virtual void VisitPattern(const Pattern& c);
virtual void VisitSpan(const Span& span);
protected:
// Internal visiting counter
std::unordered_map visit_counter_;
};
void ExprVisitor::VisitExpr(const Expr& expr) {
auto it = visit_counter_.find(expr.get());
if (it != visit_counter_.end()) {
++it->second;
} else {
using TParent = ExprFunctor;
TParent::VisitExpr(expr);
visit_counter_.insert({expr.get(), 1});
}
}
visit_counter_表记录了每个表达式(注意不是每种)的访问历史。在VisitExpr中,如果发现该表达式已经访问过,则只是递增该表达式的访问计数,而不做实质的访问操作。如果发现表达式没有遍历过,则调用基类ExprFunctor的VisitExpr,进而调用到发起VisitExpr的某个派生类的VisitExpr_。
派生类ExprMutator的定义跟ExprFunctor差不多:
class ExprMutator : public ::tvm::relay::ExprFunctor {
public:
/*!
* \brief Mutate is alias for VisitExpr
* \return expr.
*/
Expr Mutate(const Expr& expr) { return this->VisitExpr(expr); }
Expr VisitExpr(const Expr& expr) override;
Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const ConstantNode* op) override;
Expr VisitExpr_(const GlobalVarNode* op) override;
Expr VisitExpr_(const OpNode* op) override;
Expr VisitExpr_(const TupleNode* op) override;
Expr VisitExpr_(const FunctionNode* op) override;
Expr VisitExpr_(const CallNode* call_node) override;
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
Expr VisitExpr_(const RefCreateNode* op) override;
Expr VisitExpr_(const RefReadNode* op) override;
Expr VisitExpr_(const RefWriteNode* op) override;
Expr VisitExpr_(const ConstructorNode* op) override;
Expr VisitExpr_(const MatchNode* op) override;
/*!
* \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
* ways, one way would be to define a sub-class of type
* visitor for types which transform them appropriately.
*/
virtual Type VisitType(const Type& t);
virtual Clause VisitClause(const Clause& c);
virtual Pattern VisitPattern(const Pattern& c);
protected:
/*! \brief Internal map used for memoization. */
std::unordered_map memo_;
};
Expr ExprMutator::VisitExpr(const Expr& expr) {
auto it = this->memo_.find(expr);
if (it != this->memo_.end()) {
return it->second;
} else {
Expr new_expr = ExprFunctor::VisitExpr(expr);
memo_[expr] = new_expr;
return new_expr;
}
}
这里需要注意的是,ExprMutator的VisitExpr和VisitExpr_都是有返回值的,调用将返回遍历到的表达式,这样可以在VisitExpr_外对表达式做操作,比如说修改。
GraphPlanMemory分配流程中涉及的类关系图如下所示:
在该流程中分别从StorageAllocInit和StorageAllocator里面调用Run接口,Run接口调用VisitExpr,这个时候调用的是ExprVisitor::VisitExpr。而VisitExpr_则是调用的StorageAllocaBaseVisitor和DeviceAwareExprVisitor中重载的。
从这里也可以看到,ExprFunctor和ExprVisitor是纯粹作为visitor模式的实现而设计,具体的业务在各业务实现类中。