TVM的Object类是很多类的基类,详细的分析资料可以参考
深入理解TVM:Object家族 - 知乎
深入理解TVM:Object家族(二) - 知乎
TVM源码品读:万物基石——Object类(1) - 知乎
TVM源码品读:万物基石——Object(2) - 知乎
在阅读TVM C++代码的时候,有很多Object的派生类的类型转换需要追溯到Object/ObjectPtr/ObjectRef,所以这里着重分析三者之间的关系。我们可以只保留三者的包含关系代码:
class TVM_DLL Object {
public:
...
protected:
...
private:
...
friend class ObjectPtr;
...
};
template
class ObjectPtr {
public:
...
private:
Object* data_{nullptr};
...
friend class Object;
friend class ObjectRef;
...
};
class ObjectRef {
public:
...
protected:
ObjectPtr
从上面的代码可以看到,ObjectPtr的数据成员data_是一个Object指针,ObjectRef的数据成员data_是一个 ObjectPtr实例。
下面分析下几个频繁出现的类方法和接口函数
template
inline bool IsInstance() const;
判断当前实例的是不是目标类型TaregtType的实例,返回true有以下几种场景:
1.当前实例类型和TargetType是同一种类型;
2.如果当前类型是TargetType的子类;
3.如果当前实例的类型和TargetType有共同的祖先,并且当前类型和祖先的距离更远。一个形象的比喻就是,一个人的子孙是这个人的兄弟的实例,这个人的兄弟不是这个人的子孙的实例。
例如Function类型的继承链(从左到右,是从子孙到祖先):
Function, BaseFunc, RelayExpr, BaseExpr, ObjectRef
FunctionNode, BaseFuncNode, RelayExprNode, BaseExprNode, Object
现在有一个变量 BaseFuncNode base_func_node, 那么base_func_node.as
template
inline const ObjectType* ObjectRef::as() const {
if (data_ != nullptr && data_->IsInstance()) {
return static_cast(data_.get());
} else {
return nullptr;
}
}
分析代码:
1. 这个方法的实现是在ObjectRef类中,各子类在调用的时候都是调用的这个实现。所以后面分析中只说Object、ObjectPtr、ObjectRef,但是同样适用子类;
2. ObjectRef::data_是ObjectPtr
3. data_.get()是ObjectPtr::get(),方法返回的是ObjectPtr的data_成员(注意不是2中的那个data_),该成员是一个Object指针。所以data_.get()返回的是一个Object(或者Object的子类的)指针;
4. 使用static_cast强转Object指针到ObjectType*,那么这个Object指针指向的类必须是ObjectType或者ObjectType的子类在。即要求ObjectRef指向的是ObjectType,或者是ObjectType的子类。
简单的说,就是把当前对象实例(是一个引用类型)转换为祖先类指针。
例如Function类型的继承链(从左到右,是从子孙到祖先):
Function, BaseFunc, RelayExpr, BaseExpr, ObjectRef
FunctionNode, BaseFuncNode, RelayExprNode, BaseExprNode, Object
现在有一个变量 BaseFunc base_func, 那么base_func.as
GetRef
接口代码实现:
template
inline RefType GetRef(const ObjType* ptr) {
static_assert(std::is_base_of::value,
"Can only cast to the ref of same container type");
if (!RefType::_type_is_nullable) {
ICHECK(ptr != nullptr);
}
return RefType(ObjectPtr(const_cast(static_cast(ptr))));
}
这里我们忽略GetRef函数一开始的检查,只看最后return的那一句。
const_cast
ObjectPtr
explicit ObjectPtr(Object* data) : data_(data) {
if (data != nullptr) {
data_->IncRef();
}
}
IncRef是增加Object的引用次数,这里不细究
接下来RefType(xxx)是调用RefType的构造函数,参数为 ObjectPtr类型,生成一个RefType类实例。如果这个RefType就是ObjectRef,看下对应的构造函数:
explicit ObjectRef(ObjectPtr data) : data_(data) {}
这样就生成了一个ObjectRef类型。如果RefType是ObjectRef的子类,并且是ObjType对应的类型(比如IRModule和IRModuleNode),那么就可以由某个类型的指针类型,得到对应的Ref类型了。
这种转换在代码中很多,比如说IRModule::FromExprInContext中:
if (auto* func_node = expr.as()) {
func = GetRef(func_node);
这里将BaseFuncNode* func_node包装为引用类型BaseFunc
这个函数是将一个较为基础的类型的引用类型,包装为它的某个子类型。
template
inline SubRef Downcast(BaseRef ref) {
if (ref.defined()) {
ICHECK(ref->template IsInstance())
<< "Downcast from " << ref->GetTypeKey() << " to " << SubRef::ContainerType::_type_key
<< " failed.";
} else {
ICHECK(SubRef::_type_is_nullable) << "Downcast from nullptr to not nullable reference of "
<< SubRef::ContainerType::_type_key;
}
return SubRef(std::move(ref.data_));
}
我们先假定SubRef和BaseRef都是ObjectRef类型或者ObjectRef的子类。这里首先要求ref类型是SubRef类型的实例(IsInstance), 然后返回ref.data_的SubRef类型实例。ref.data_是BaseRef的Ptr类型。这里我们可以先看下ObjectRef的构造函数:
explicit ObjectRef(ObjectPtr data) : data_(data) {}
这里直接给 data_赋值,没有其他操作。
也就是这种转换,输入的实例虽然被看作是一个较为基础的类型,但是实际上它是SubRef的类型或者SubRef的子类实例?否则这种Downcast不太合理吧。