This article uses the torch.jit.trace interface as a doorway into the PyTorch JIT: it walks through how an ordinary nn.Module becomes a ScriptModule after tracing, and the C++ classes involved along the way. Since there is a lot of ground to cover, we look at the source of a handful of classes that are relatively important, or that help build understanding.
Before we begin, consider a question that looks simple but is not necessarily easy to answer: why do programming languages distinguish data types at all?
Roughly speaking, a computer implements different functionality with different circuits, and those different circuits surface at the language level as different data types. So the most basic step in approaching a programming language is understanding its data types.
Types come first here as well. Working through the jit code, everything eventually lands on the torch::jit::Type class. This base class represents the different types; the listing below currently contains 34 kinds, which also cover the types in IValue's Tag.
// torch/include/ATen/core/jit_type_base.h
#define C10_FORALL_TYPES(_) \
_(AnyType) \
_(EnumType) \
_(AnyEnumType) \
_(TensorType) \
_(StorageType) \
_(TupleType) \
_(ListType) \
_(DictType) \
_(NumberType) \
_(FloatType) \
_(ComplexType) \
_(FutureType) \
_(RRefType) \
_(IntType) \
_(NoneType) \
_(StringType) \
_(GeneratorType) \
_(QuantizerType) \
_(BoolType) \
_(OptionalType) \
_(VarType) \
_(DeviceObjType) \
_(StreamObjType) \
_(FunctionType) \
_(ClassType) \
_(PyObjectType) \
_(CapsuleType) \
_(InterfaceType) \
_(QSchemeType) \
_(LayoutType) \
_(ScalarTypeType) \
_(AnyListType) \
_(AnyTupleType) \
_(AnyClassType)
enum class TypeKind {
#define DEFINE_TYPE(T) T,
C10_FORALL_TYPES(DEFINE_TYPE)
#undef DEFINE_TYPE
};
The Type class is mainly built around the TypeKind enum above; on top of that it exposes the printable name (str()), subtype queries, whether the type is a module, the various cast functions, and access to the Types contained in a compound type. Part of the source:
struct TORCH_API Type : std::enable_shared_from_this<Type> {
 private:
  TypeKind kind_;

 protected:
  Type(TypeKind kind) : kind_(kind) {}

 public:
  virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const;
  virtual std::string str() const = 0;
  template <typename T>
  std::shared_ptr<T> cast() {
    if (T::Kind == kind()) {
      return std::static_pointer_cast<T>(shared_from_this());
    }
    return nullptr;
  }
  virtual at::ArrayRef<TypePtr> containedTypes() const {
    return {};
  }
};
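To make the interface concrete, here is a minimal usage sketch (inspect is an invented helper; the ListType factory and getElementType accessor come from jit_type.h):

#include <ATen/core/jit_type.h>
#include <cstdio>

// Minimal sketch: query a TypePtr's kind and downcast it.
// cast<T>() returns nullptr when the kind doesn't match, so it
// doubles as a checked downcast.
void inspect(const c10::TypePtr& t) {
  if (auto list = t->cast<c10::ListType>()) {
    printf("List[%s]\n", list->getElementType()->str().c_str());
  } else {
    printf("%s\n", t->str().c_str());
  }
}

// e.g. inspect(c10::ListType::ofInts()) prints "List[int]".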
The various subclasses of Type are implemented under pytorch/pytorch/torch/include/ATen/core/jit_type.h. The most important ones here are TensorType and ClassType; since Tensor was already covered earlier, we skip it for now and focus on ClassType.
Alongside it, a helper class ClassAttribute is defined. A class attribute revolves around a name, a kind (AttributeKind), and a type (TypePtr), which gives the following definition:
struct TORCH_API ClassAttribute {
public:
ClassAttribute(AttributeKind kind,
TypePtr attributeType,
std::string attributeName) :
kind_(kind),
attributeType_(attributeType),
attributeName_(std::move(attributeName)) {}
AttributeKind getKind() const {
return kind_;
}
TypePtr getType() const {
return attributeType_;
}
const std::string& getName() const {
return attributeName_;
}
private:
AttributeKind kind_;
TypePtr attributeType_;
std::string attributeName_;
};
where:
enum class AttributeKind {
BUFFER,
PARAMETER,
REGULAR_ATTRIBUTE
};
The ClassType definition is also quite long; the relevant parts are listed below:
struct TORCH_API ClassType : public NamedType {
  // Create a class type with name `name` and its methods stored in `cu`.
  static ClassTypePtr create(
      c10::optional<QualifiedName> qualifiedName,
      std::weak_ptr<CompilationUnit> cu,
      bool is_module = false,
      std::string doc_string = "",
      std::vector<std::string> unresolved_class_attributes = {});
  const std::vector<torch::jit::Function*>& methods() const;
  std::string str() const override;
  bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
  ...
  //—————————————————————— attribute operations ——————————————————————
  TypePtr findAttribute(const std::string& name) const;
  TypePtr getAttribute(const std::string& name) const;
  size_t numAttributes() const;
  size_t addOrCheckAttribute(
      const std::string& name,
      TypePtr ty,
      bool is_parameter = false,
      bool is_buffer = false);
  ...
  //—————————————————————— constant operations ——————————————————————
  size_t addConstant(const std::string& name, const IValue& value);
  IValue getConstant(const std::string& name) const;
  at::ArrayRef<IValue> constantValues() const;
  ...
  //—————————————————————— method-related operations ——————————————————————
  void addForwardPreHook(torch::jit::Function* pre_hook_ptr);
  void addForwardHook(torch::jit::Function* hook_ptr);
  const std::vector<torch::jit::Function*>& getForwardHooks() const;
  const std::vector<torch::jit::Function*>& getForwardPreHooks() const;
  void addMethod(torch::jit::Function* method);
  torch::jit::Function* findMethod(const std::string& name) const;
  torch::jit::Function& getMethod(const std::string& name) const;
  torch::jit::Function* findHook(const std::string& name) const;

 private:
  std::vector<std::string> constantNames_;
  std::vector<IValue> constantValues_;
  std::vector<ClassAttribute> attributes_;
  std::vector<TypePtr> attributeTypes_;
  std::vector<torch::jit::Function*> methods_;
  std::vector<torch::jit::Function*> staticmethods_;
  std::vector<torch::jit::Function*> forward_hooks_;
  std::vector<torch::jit::Function*> forward_pre_hooks_;
  ...
};
Here QualifiedName represents dotted names of the form foo.bar.baz, and CompilationUnit can be thought of as a list of named functions: it stores a class's functions and provides interfaces to iterate over and invoke them. A ClassType therefore behaves much like a real class. Beyond the shared methods inherited from Type, its member functions center on attributes (ClassAttribute), methods (method/hook/pre-hook), and constant operations; constructing one requires a compilation_unit.
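As a minimal sketch of this construction path (the qualified name "__torch__.MyModule" and attribute "weight" are invented for illustration; create and addOrCheckAttribute are the interfaces listed above):

auto cu = std::make_shared<torch::jit::CompilationUnit>();
auto cls = c10::ClassType::create(
    c10::QualifiedName("__torch__.MyModule"), cu, /*is_module=*/true);
// Registers a PARAMETER-kind attribute of type Tensor; returns its slot index.
size_t slot = cls->addOrCheckAttribute(
    "weight", c10::TensorType::get(), /*is_parameter=*/true);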
With ClassType in place to describe classes, we can now create objects of those classes. The first class to introduce is c10::ivalue::Object, defined in ivalue_inl.h. It is not long; the class definition follows:
// torch/include/ATen/core/ivalue_inl.h
struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
 public:
  Object(StrongTypePtr type, size_t numSlots) : type_(std::move(type)) {
    slots_.resize(numSlots);
  }
  static c10::intrusive_ptr<Object> create(StrongTypePtr type, size_t numSlots);
  void setSlot(size_t slot, IValue v);
  ...
 private:
  StrongTypePtr type_;
  std::vector<IValue> slots_;
};
where StrongTypePtr is:
// torch/include/ATen/core/ivalue.h
struct TORCH_API StrongTypePtr {
  StrongTypePtr(
      std::shared_ptr<torch::jit::CompilationUnit> cu,
      std::shared_ptr<Type> type);

  std::shared_ptr<torch::jit::CompilationUnit> cu_;
  std::shared_ptr<Type> type_;
};
So ivalue::Object essentially holds a std::vector<IValue> of slots, one per attribute, together with a StrongTypePtr tying the object to its type and its CompilationUnit.
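A short sketch of building one by hand (assuming the usual ATen/JIT headers; the class name "__torch__.Holder" and attribute "x" are invented):

auto cu = std::make_shared<torch::jit::CompilationUnit>();
auto cls = c10::ClassType::create(c10::QualifiedName("__torch__.Holder"), cu);
cls->addOrCheckAttribute("x", c10::TensorType::get());
// One IValue slot per attribute; setSlot addresses them by index.
auto obj = c10::ivalue::Object::create(
    c10::StrongTypePtr(cu, cls), cls->numAttributes());
obj->setSlot(0, c10::IValue(at::ones({2, 2})));  // fill the "x" slot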
Next comes the torch::jit::Object class.
// pytorch/pytorch/torch/csrc/jit/api/object.h
using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;
struct TORCH_API Object {
  Object() = default;
  Object(ObjectPtr _ivalue) : _ivalue_(std::move(_ivalue)) {}
  Object(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
  Object(
      c10::QualifiedName,
      std::shared_ptr<CompilationUnit> cu,
      bool shouldMangle = false);
  ObjectPtr _ivalue() const;
  c10::ClassTypePtr type() const {
    return _ivalue()->type();
  }
  void setattr(const std::string& name, c10::IValue v) {
    if (_ivalue()->type()->hasConstant(name)) {
      TORCH_CHECK(
          false,
          "Can't set constant '",
          name,
          "' which has value:",
          _ivalue()->type()->getConstant(name));
    } else if (auto slot = _ivalue()->type()->findAttributeSlot(name)) {
      const c10::TypePtr& expected = _ivalue()->type()->getAttribute(*slot);
      TORCH_CHECK(
          v.type()->isSubtypeOf(expected),
          "Expected a value of type '",
          expected->repr_str(),
          "' for field '",
          name,
          "', but found '",
          v.type()->repr_str(),
          "'");
      _ivalue()->setSlot(*slot, std::move(v));
    } else {
      TORCH_CHECK(false, "Module has no attribute '", name, "'");
    }
  }
  c10::IValue attr(const std::string& name) const {
    if (auto r = _ivalue()->type()->findAttributeSlot(name)) {
      return _ivalue()->getSlot(*r);
    }
    if (auto r = _ivalue()->type()->findConstantSlot(name)) {
      return _ivalue()->type()->getConstant(*r);
    }
    std::stringstream err;
    err << _ivalue()->type()->repr_str() << " does not have a field with name '"
        << name.c_str() << "'";
    throw ObjectAttributeError(err.str());
  }
  c10::IValue attr(const std::string& name, c10::IValue or_else) const {
    if (auto r = _ivalue()->type()->findAttributeSlot(name)) {
      return _ivalue()->getSlot(*r);
    }
    if (auto r = _ivalue()->type()->findConstantSlot(name)) {
      return _ivalue()->type()->getConstant(*r);
    }
    return or_else;
  }
  bool hasattr(const std::string& name) const {
    return _ivalue()->type()->hasAttribute(name) ||
        _ivalue()->type()->hasConstant(name);
  }
  // every object has its own methods
  Method get_method(const std::string& name) const {
    if (auto method = find_method(name)) {
      return *method;
    }
    AT_ERROR("Method '", name, "' is not defined.");
  }
  const std::vector<Method> get_methods() const {
    return c10::fmap(type()->methods(), [&](Function* func) {
      return Method(_ivalue(), func);
    });
  }
  c10::optional<Method> find_method(const std::string& basename) const;
  template <typename... Types>
  IValue run_method(const std::string& method_name, Types&&... args) {
    return get_method(method_name)({IValue(std::forward<Types>(args))...});
  }
  void define(const std::string& src, const ResolverPtr& resolver = nullptr);
  size_t num_slots() const {
    return _ivalue()->slots().size();
  }
  // copy operations
  Object copy() const;
  Object deepcopy() const;

 private:
  mutable ObjectPtr _ivalue_;
};
torch::jit::Object is an extension built on top of c10::ivalue::Object; the copy operations, for instance, delegate directly to the c10::ivalue::Object implementation.
Looking at the constructors, a torch::jit::Object is built from the pieces behind a ClassType: a CompilationUnit and a ClassTypePtr. As the class definition shows, torch::jit::Object mostly wraps c10::ivalue::Object, and its operations go through the ClassType interface reachable as _ivalue()->type() to perform the basic attr and method manipulations.
Given the inheritance chain ClassType -> NamedType -> Type, torch::jit::Object is in essence ivalue::Object plus ClassType: attribute operations, Method operations, and shallow/deep copies of the object. The CompilationUnit pointer cu, supplied at construction, is effectively a list of Functions.
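Putting the pieces together, a hypothetical end-to-end sketch of the name-based interface (all names invented; that the constructor sizes the slots from the ClassType is an assumption consistent with num_slots() above):

auto cu = std::make_shared<torch::jit::CompilationUnit>();
auto cls = c10::ClassType::create(c10::QualifiedName("__torch__.Holder"), cu);
cls->addOrCheckAttribute("x", c10::TensorType::get());

torch::jit::Object obj(cu, cls);          // wraps an ivalue::Object underneath
obj.setattr("x", at::ones({2, 2}));       // type-checked against the ClassType
at::Tensor x = obj.attr("x").toTensor();  // reads the slot back by name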
Before moving on, a few concepts to pin down:
A "module" is an abstraction over the implementation of some function or algorithm. A Module mainly carries a few kinds of state: parameters, buffers, and regular attributes.
These correspond to AttributeKind in ClassType. Note that cloning a Module performs a deep copy. Module provides register_parameter and register_buffer to register the two different kinds of Tensors.
Modules can be nested: a Module may contain submodules.
With that understanding, let's step into the torch::jit::Module class and explore its operations.
// pytorch/pytorch/torch/csrc/jit/api/module.h
struct TORCH_API Module : public Object {
  // a set of constructors, mirroring the Object constructors above
  explicit Module(c10::QualifiedName class_name);
  Module(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
  Module() = default;
  Module(
      c10::QualifiedName,
      std::shared_ptr<CompilationUnit> cu,
      bool shouldMangle = false);
  Module(ModulePtr module_value) : Object(std::move(module_value)) {}
  ~Module() = default;

  // the main forward entry point
  IValue forward(std::vector<IValue> inputs) {
    return get_method("forward")(std::move(inputs));
  }

  // registration of the Module's key elements
  void register_buffer(const std::string& name, at::Tensor v) {
    bool is_param = false;
    bool is_buffer = true;
    type()->addOrCheckAttribute(name, TensorType::get(), is_param, is_buffer);
    _ivalue()->setAttr(name, std::move(v));
  }
  void register_parameter(
      const std::string& name,
      at::Tensor v,
      bool is_buffer) {
    type()->addOrCheckAttribute(name, TensorType::get(), !is_buffer, is_buffer);
    _ivalue()->setAttr(name, std::move(v));
  }

  // recursively apply fn
  void apply(const std::function<void(Module&)>& fn);
  ...

  // accessors for the module's actual contents
  buffer_list buffers(bool recurse = true) const;
  named_buffer_list named_buffers(bool recurse = true) const;
  module_list children() const; // direct modules
  named_module_list named_children() const;
  module_list modules() const; // all modules, including this one, recursively
  named_module_list named_modules() const;
  // all tensors involved in gradient optimization
  parameter_list parameters(bool recurse = true) const;
  named_parameter_list named_parameters(bool recurse = true) const;
  // all members of the object, similar to iterating over dir(obj) in python
  attribute_list attributes(bool recurse = true) const;
  named_attribute_list named_attributes(bool recurse = true) const;

  // dump
  void dump(
      bool print_method_bodies,
      bool print_attr_values,
      bool print_param_values) const;
  ...
  // also to(), save(), copy(), deepcopy(), clone(), clone_method(), etc.,
  // which are not expanded here.
};
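For completeness, this is the canonical C++ consumer of such a Module; "traced.pt" is a hypothetical file produced by traced.save(...) on the Python side:

#include <torch/script.h>
#include <iostream>

int main() {
  // torch::jit::load deserializes a torch::jit::Module from disk.
  torch::jit::Module m = torch::jit::load("traced.pt");
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::randn({1, 3, 5, 5}));
  // forward() dispatches to the module's "forward" Method.
  at::Tensor out = m.forward(inputs).toTensor();
  std::cout << out.sizes() << std::endl;
  return 0;
}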
When we pass an nn.Module to torch.jit.trace, the Python side first converts it into a script module. The model definition used here:
import torch

class TestArangeModel(torch.nn.Module):
    def __init__(self):
        super(TestArangeModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3, stride=(1, 1), padding=(1, 1), bias=False)
        self.conv2 = torch.nn.Conv2d(3, 3, 3, stride=(1, 1), padding=(1, 1))

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

m = TestArangeModel()
input = torch.randn(1, 3, 5, 5)
traced = torch.jit.trace(m, input)
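On the C++ side the printout below comes from Module::dump, whose three flags match the declaration shown earlier (a sketch, assuming the traced module is available as torch::jit::Module m):

m.dump(/*print_method_bodies=*/true,
       /*print_attr_values=*/true,
       /*print_param_values=*/false);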
Dumped via the dump interface, the jit::Module passed in looks roughly like this:
module __torch__.TestArangeModel {
parameters {
}
attributes {
training = True
_is_full_backward_hook = None
conv1 = <__torch__.torch.nn.modules.conv.Conv2d object at 0x557569c9f300>
conv2 = <__torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d object at 0x557569cb4080>
}
methods {
}
submodules {
module __torch__.torch.nn.modules.conv.Conv2d {
parameters {
weight = ...
}
attributes {
weight = ...
training = True
_is_full_backward_hook = None
}
methods {
}
submodules {
}
} // the QualifiedName below is: __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d
module __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d {
parameters {
weight = ... // parameters and buffers both appear under parameters; they are distinguished underneath
bias = ...
}
attributes {
weight = ... // when registered, parameters and buffers are also added to attributes, alongside the other attributes
bias = ...
training = True
_is_full_backward_hook = None
}
methods {
}
submodules { // modules themselves are nestable
}
}
}
}
torch.jit.trace is essentially the process of converting a torch::jit::Module into a torch::jit::Graph. The heart of it is the series of transformations applied to the graph once it is built; an intermediate class named TracingState tracks the state along the way, and a transformed torch::jit::Graph is returned at the end.
Having introduced the classes involved, this part outlines the rough workflow of the jit.trace interface on the C++ side.
As seen above, the initial Module has no methods (likely because it was converted from an nn.Module, whose class definition itself carries nothing method-related). Let's now work through, step by step, how the graph is built and what the Module ends up looking like. Converted to a Graph, the Module above initially becomes:
graph(%self : __torch__.___torch_mangle_3.TestArangeModel,
%x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
%1 : bool = prim::TracedAttr[scope="__module.training"]()
%2 : NoneType = prim::TracedAttr[scope="__module._is_full_backward_hook"]()
%3 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d = prim::TracedAttr[scope="__module.conv1"]()
%4 : Float(3, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv1.weight"]()
%5 : bool = prim::TracedAttr[scope="__module.conv1.training"]()
%6 : NoneType = prim::TracedAttr[scope="__module.conv1._is_full_backward_hook"]()
%7 : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d = prim::TracedAttr[scope="__module.conv2"]()
%8 : Float(3, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv2.weight"]()
%9 : Float(3, strides=[1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv2.bias"]()
%10 : bool = prim::TracedAttr[scope="__module.conv2.training"]()
%11 : NoneType = prim::TracedAttr[scope="__module.conv2._is_full_backward_hook"]()
return ()
That is the state right after the Module is converted into a Graph: nothing but prim::TracedAttr nodes, with no real computation recorded yet.
After the forward pass being traced has run through it, we get:
graph(%self : __torch__.TestArangeModel,
%x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
%1 : bool = prim::TracedAttr[scope="__module.training"]()
%2 : NoneType = prim::TracedAttr[scope="__module._is_full_backward_hook"]()
%3 : __torch__.torch.nn.modules.conv.Conv2d = prim::TracedAttr[scope="__module.conv1"]()
%weight.1 : Float(3, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv1.weight"]()
%5 : bool = prim::TracedAttr[scope="__module.conv1.training"]()
%6 : NoneType = prim::TracedAttr[scope="__module.conv1._is_full_backward_hook"]()
%7 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d = prim::TracedAttr[scope="__module.conv2"]()
%weight : Float(3, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv2.weight"]()
%bias : Float(3, strides=[1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv2.bias"]()
%10 : bool = prim::TracedAttr[scope="__module.conv2.training"]()
%11 : NoneType = prim::TracedAttr[scope="__module.conv2._is_full_backward_hook"]()
= prim::TracedModuleForward[scope="__module.conv1"](), scope: __module.conv1
block0():
%13 : NoneType = prim::Constant(), scope: __module.conv1
%14 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%15 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%16 : int[] = prim::ListConstruct(%14, %15), scope: __module.conv1
%17 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%18 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%19 : int[] = prim::ListConstruct(%17, %18), scope: __module.conv1
%20 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%21 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%22 : int[] = prim::ListConstruct(%20, %21), scope: __module.conv1
%23 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%24 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%25 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%26 : int[] = prim::ListConstruct(%24, %25), scope: __module.conv1
%27 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%28 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%29 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%30 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%31 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%input : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=1, device=cpu) = aten::_convolution(%x, %weight.1, %13, %16, %19, %22, %23, %26, %27, %28, %29, %30, %31), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
-> ()
= prim::TracedModuleForward[scope="__module.conv2"](), scope: __module.conv2
block0():
%33 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%34 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%35 : int[] = prim::ListConstruct(%33, %34), scope: __module.conv2
%36 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%37 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%38 : int[] = prim::ListConstruct(%36, %37), scope: __module.conv2
%39 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%40 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%41 : int[] = prim::ListConstruct(%39, %40), scope: __module.conv2
%42 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%43 : int = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%44 : int = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%45 : int[] = prim::ListConstruct(%43, %44), scope: __module.conv2
%46 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%47 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%48 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%49 : bool = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%50 : bool = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%51 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=1, device=cpu) = aten::_convolution(%input, %weight, %bias, %35, %38, %41, %42, %45, %46, %47, %48, %49, %50), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
-> ()
return (%51)
Next, a series of passes processes this graph further. Execution enters the fix-up passes in pytorch/pytorch/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp:
void FixupTraceScopeBlocks(std::shared_ptr<Graph>& graph, Module* self) {
if (self) {
ConvertTracedAttrReferences().run(graph);
} else {
for (Node* n : graph->nodes()) {
TORCH_INTERNAL_ASSERT(n->kind() != prim::TracedAttr);
}
}
MakeDefsDominateUses().run(graph->block());
convertReturnsToTuples(graph->block());
if (!self) {
// We have no Module, so we're just going to inline everything.
// This should give us a totally flat graph.
inlineScopeBlocks(graph->block());
// For TracedFork nodes
lambdaLiftBlocksAndConvertToGraph(graph->block());
runCleanupPasses(graph);
} else {
lambdaLiftBlocksAndConvertToGraph(graph->block());
createMethodCalls(graph);
runCleanupPasses(self);
// `graph` isn't referenced in `self` yet, so we need to run
// this separately
runCleanupPasses(graph);
}
}
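Most of these sub-passes share one skeleton: walk every node in a block, recurse into sub-blocks, and match on node kind. A sketch of that pattern (visitTracedForward is an invented helper; the scope read in the comment assumes the attribute is named "scope"):

void visitTracedForward(torch::jit::Block* b) {
  for (torch::jit::Node* n : b->nodes()) {
    if (n->kind() == torch::jit::prim::TracedModuleForward) {
      // process the node here, e.g. read its scope string:
      // auto scope = n->s(torch::jit::attr::scope);  // assumed attribute name
    }
    for (torch::jit::Block* sub : n->blocks()) {
      visitTracedForward(sub);  // descend into block0 etc.
    }
  }
}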
The function's main job is to sort out the relationships between nodes in the graph and then deal with the prim::Traced* nodes, in roughly these steps:
1. Recursively find the prim::TracedModuleForward nodes; for each one, locate the prim::TracedAttr with the same scope to fetch the corresponding submodule from the input module, pass it in as an input of the prim::TracedModuleForward node, and add the matching input to its block0.
2. Insert the necessary prim::GetAttr nodes to replace the prim::TracedAttr nodes that are actually used, rewiring their outputs accordingly.
3. Delete the now-useless prim::TracedAttr nodes.
4. For the bare `= prim::TracedModuleForward` form, add input/output names, types, and related information.
5. Convert multiple outputs into a single tuple.
6. Lift each TracedModuleForward block into a new subgraph and store that subgraph as the node's Subgraph attribute.
This roughly turns the graph from:
%59 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward[scope="__module.conv1"](%conv1), scope: __module.conv1
block0(%self.3 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d):
%weight.5 : Tensor = prim::GetAttr[name="weight"](%self.3)
%13 : NoneType = prim::Constant(), scope: __module.conv1
%14 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%15 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%16 : int[] = prim::ListConstruct(%14, %15), scope: __module.conv1
%17 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%18 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%19 : int[] = prim::ListConstruct(%17, %18), scope: __module.conv1
%20 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%21 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%22 : int[] = prim::ListConstruct(%20, %21), scope: __module.conv1
%23 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%24 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%25 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%26 : int[] = prim::ListConstruct(%24, %25), scope: __module.conv1
%27 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%28 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%29 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%30 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%31 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%input : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = aten::_convolution(%x, %weight.5, %13, %16, %19, %22, %23, %26, %27, %28, %29, %30, %31), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
-> (%input)
into:
%59 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward[scope="__module.conv1", Subgraph=](%conv1, %x), scope: __module.conv1
block0(%self.3 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d):
%weight.5 : Tensor = prim::GetAttr[name="weight"](%self.3)
%13 : NoneType = prim::Constant(), scope: __module.conv1
%14 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%15 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%16 : int[] = prim::ListConstruct(%14, %15), scope: __module.conv1
%17 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%18 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%19 : int[] = prim::ListConstruct(%17, %18), scope: __module.conv1
%20 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%21 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%22 : int[] = prim::ListConstruct(%20, %21), scope: __module.conv1
%23 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%24 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%25 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%26 : int[] = prim::ListConstruct(%24, %25), scope: __module.conv1
%27 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%28 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%29 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%30 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%31 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%input : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = aten::_convolution(%x, %weight.5, %13, %16, %19, %22, %23, %26, %27, %28, %29, %30, %31), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
-> (%input)
The block is now attached to the prim::TracedModuleForward node as an attribute, so the inline block can be deleted, leaving:
%59 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward[scope="__module.conv1", Subgraph=](%conv1, %x), scope: __module.conv1
Finally, each prim::TracedModuleForward is converted into a prim::CallMethod node, and n->output()->replaceAllUsesWith(retval); replaces the old node with the new one:
%61 : Tensor = prim::CallMethod[name="forward"](%conv1, %x)
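Mechanically this follows the standard IR-mutation recipe. A sketch with illustrative variable names, not the exact pass code (the real pass also fixes up types and method signatures):

// n: the old prim::TracedModuleForward node; conv1_val/x_val: its inputs.
torch::jit::Node* call =
    graph->create(torch::jit::prim::CallMethod, {conv1_val, x_val});
call->s_(torch::jit::attr::name, "forward");  // the method-name attribute
call->insertBefore(n);
n->output()->replaceAllUsesWith(call->output());
n->destroy();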
At this point the graph has become:
graph(%self.1 : __torch__.___torch_mangle_3.TestArangeModel,
%x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
%conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d = prim::GetAttr[name="conv2"](%self.1)
%conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d = prim::GetAttr[name="conv1"](%self.1)
%61 : Tensor = prim::CallMethod[name="forward"](%conv1, %x)
%59 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward_0[scope="__module.conv1"](%conv1, %x), scope: __module.conv1
%60 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward_1[scope="__module.conv2"](%conv2, %61), scope: __module.conv2
return (%60)
The last step is removing the unused nodes:
graph(%self.1 : __torch__.___torch_mangle_3.TestArangeModel,
%x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
%conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d = prim::GetAttr[name="conv2"](%self.1)
%conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d = prim::GetAttr[name="conv1"](%self.1)
%61 : Tensor = prim::CallMethod[name="forward"](%conv1, %x)
%60 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward_0[scope="__module.conv2"](%conv2, %61), scope: __module.conv2
return (%60)
To summarize, the fix-up passes mainly do the following: when a prim::TracedModuleForward node is replaced with prim::CallMethod, the block that was converted into the node's Subgraph is simultaneously inserted into the submodule's methods, so the Module that was passed in is being mutated along the way.
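Inserting the lifted subgraph as a real method boils down to something like this sketch (variable names illustrative; it relies on the CompilationUnit::create_function and ClassType::addMethod interfaces shown earlier):

// cu: the CompilationUnit; sub_cls: the submodule's ClassType;
// subgraph: the Graph lifted out of the TracedModuleForward block.
torch::jit::Function* fn = cu->create_function(
    c10::QualifiedName(*sub_cls->name(), "forward"), subgraph);
sub_cls->addMethod(fn);  // the submodule's methods{} is no longer empty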
Repeating these steps for every traced call, the final traced graph looks like this:
graph(%self.1 : __torch__.___torch_mangle_3.TestArangeModel,
%x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
%conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d = prim::GetAttr[name="conv2"](%self.1)
%conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d = prim::GetAttr[name="conv1"](%self.1)
%61 : Tensor = prim::CallMethod[name="forward"](%conv1, %x)
%62 : Tensor = prim::CallMethod[name="forward"](%conv2, %61)
return (%62)
And finally, the converted module:
module __torch__.___torch_mangle_3.TestArangeModel {
parameters {
}
attributes {
training = True
_is_full_backward_hook = None
conv1 = <__torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d object at 0x559c691cd470>
conv2 = <__torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d object at 0x559c691cdcb0>
}
methods {
method forward {
graph(%self.1 : __torch__.___torch_mangle_3.TestArangeModel,
%x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
%conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d = prim::GetAttr[name="conv2"](%self.1)
%conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d = prim::GetAttr[name="conv1"](%self.1)
%61 : Tensor = prim::CallMethod[name="forward"](%conv1, %x)
%62 : Tensor = prim::CallMethod[name="forward"](%conv2, %61)
return (%62)
}
}
submodules {
module __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d {
parameters {
weight = ...
}
attributes {
weight = ...
training = True
_is_full_backward_hook = None
}
methods {
method forward {
graph(%self.3 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d,
%x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
%weight.5 : Tensor = prim::GetAttr[name="weight"](%self.3)
%2 : NoneType = prim::Constant(), scope: __module.conv1
%3 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%4 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%5 : int[] = prim::ListConstruct(%3, %4), scope: __module.conv1
%6 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%7 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%8 : int[] = prim::ListConstruct(%6, %7), scope: __module.conv1
%9 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%10 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%11 : int[] = prim::ListConstruct(%9, %10), scope: __module.conv1
%12 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%13 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%14 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%15 : int[] = prim::ListConstruct(%13, %14), scope: __module.conv1
%16 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%17 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%18 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%19 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%20 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%input : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = aten::_convolution(%x, %weight.5, %2, %5, %8, %11, %12, %15, %16, %17, %18, %19, %20), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
return (%input)
}
}
submodules {
}
}
module __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d {
parameters {
weight = ...
bias = ...
}
attributes {
weight = ...
bias = ...
training = True
_is_full_backward_hook = None
}
methods {
method forward {
graph(%self : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d,
%22 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
%bias : Tensor = prim::GetAttr[name="bias"](%self)
%weight : Tensor = prim::GetAttr[name="weight"](%self)
%3 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%4 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%5 : int[] = prim::ListConstruct(%3, %4), scope: __module.conv2
%6 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%7 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%8 : int[] = prim::ListConstruct(%6, %7), scope: __module.conv2
%9 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%10 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%11 : int[] = prim::ListConstruct(%9, %10), scope: __module.conv2
%12 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%13 : int = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%14 : int = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%15 : int[] = prim::ListConstruct(%13, %14), scope: __module.conv2
%16 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%17 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%18 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%19 : bool = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%20 : bool = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
%21 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = aten::_convolution(%22, %weight, %bias, %5, %8, %11, %12, %15, %16, %17, %18, %19, %20), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
return (%21)
}
}
submodules {
}
}
}
}
In short, torch.jit.trace is the process of converting a module into a JIT graph, and the above walks through that process.
Most of the work is operating on the nodes of the Graph; readers who need more can follow the thread laid out above and dig further into the source for whatever they need.