[PyTorch Source Reading] — The torch.jit.trace Interface: A C++ Source Walkthrough

Table of Contents

      • Preface
      • torch::jit::Type
      • ClassType
      • c10::ivalue::Object
      • torch::jit::Object
      • torch::jit::Module
      • Conversion Process
      • Summary

Preface

This post uses the torch.jit.trace interface as a doorway into PyTorch JIT. It walks through how an ordinary nn.Module becomes a ScriptModule after tracing, and the C++ classes involved along the way. Since there is a lot of ground to cover, we focus, from the source code's point of view, on the classes that are relatively important or that help with understanding.

Before diving in, consider a question that looks simple but is not so easy to answer: why do compiled languages have data types at all?

Roughly speaking, inside a computer, different functionality is implemented by different circuits, and at the language level those different circuits surface as different data types. So the most fundamental step when approaching a programming language is to understand its data types.

torch::jit::Type

Types come first. When working through the JIT code, everything ultimately lands on the torch::jit::Type class. This base class represents the different types; the macro below currently enumerates 34 kinds, which also cover the types found in IValue's Tag.

// torch/include/ATen/core/jit_type_base.h 
#define C10_FORALL_TYPES(_) \
  _(AnyType)                \
  _(EnumType)               \
  _(AnyEnumType)            \
  _(TensorType)             \
  _(StorageType)            \
  _(TupleType)              \
  _(ListType)               \
  _(DictType)               \
  _(NumberType)             \
  _(FloatType)              \
  _(ComplexType)            \
  _(FutureType)             \
  _(RRefType)               \
  _(IntType)                \
  _(NoneType)               \
  _(StringType)             \
  _(GeneratorType)          \
  _(QuantizerType)          \
  _(BoolType)               \
  _(OptionalType)           \
  _(VarType)                \
  _(DeviceObjType)          \
  _(StreamObjType)          \
  _(FunctionType)           \
  _(ClassType)              \
  _(PyObjectType)           \
  _(CapsuleType)            \
  _(InterfaceType)          \
  _(QSchemeType)            \
  _(LayoutType)             \
  _(ScalarTypeType)         \
  _(AnyListType)            \
  _(AnyTupleType)           \
  _(AnyClassType)

enum class TypeKind {
#define DEFINE_TYPE(T) T,
  C10_FORALL_TYPES(DEFINE_TYPE)
#undef DEFINE_TYPE
};

The Type class is mainly built around the TypeKind enum above. It lets you query the printable name, whether one type is a subtype of another, and whether it is a module; it provides the various cast functions; and it exposes the Types contained inside a compound type. Part of the source is shown below:

struct TORCH_API Type : std::enable_shared_from_this<Type> {
 private:
  TypeKind kind_;
 protected:
  Type(TypeKind kind) : kind_(kind) {}
 public:
  virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const;
  virtual std::string str() const = 0;
  template <typename T>
  std::shared_ptr<T> cast() {
    if (T::Kind == kind()) {
      return std::static_pointer_cast<T>(shared_from_this());
    }
    return nullptr;
  }

  virtual at::ArrayRef<TypePtr> containedTypes() const {
    return {};
  }
};
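
To make the cast machinery concrete, here is a minimal usage sketch (assuming the usual libtorch headers and linkage; this snippet is illustrative and not from the PyTorch sources):

#include <ATen/core/jit_type.h>
#include <iostream>

int main() {
  // Build a List[Tensor] type and inspect it through the Type interface.
  c10::TypePtr t = c10::ListType::ofTensors();
  std::cout << t->str() << "\n"; // prints the type, e.g. "Tensor[]"

  // cast<T>() succeeds only when the TypeKind matches...
  if (auto list = t->cast<c10::ListType>()) {
    std::cout << list->getElementType()->str() << "\n"; // "Tensor"
  }
  // ...and returns nullptr otherwise.
  if (t->cast<c10::TensorType>() == nullptr) {
    std::cout << "not a TensorType\n";
  }
  return 0;
}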

ClassType

The various subclasses of Type are implemented under pytorch/pytorch/torch/include/ATen/core/jit_type.h. The most important ones here are TensorType and ClassType; since Tensor was covered earlier, we will not expand on TensorType and instead focus on ClassType.

A helper class, ClassAttribute, describes a class attribute in terms of three things: its name, its kind (AttributeKind), and its type (TypePtr), which gives the following definition:

struct TORCH_API ClassAttribute {
 public:
  ClassAttribute(
      AttributeKind kind,
      TypePtr attributeType,
      std::string attributeName)
      : kind_(kind),
        attributeType_(attributeType),
        attributeName_(std::move(attributeName)) {}

  AttributeKind getKind() const {
    return kind_;
  }

  TypePtr getType() const {
    return attributeType_;
  }

  const std::string& getName() const {
    return attributeName_;
  }

 private:
  AttributeKind kind_;
  TypePtr attributeType_;
  std::string attributeName_;
};

where:

enum class AttributeKind {
  BUFFER,
  PARAMETER,
  REGULAR_ATTRIBUTE
};

The ClassType class definition is also quite long; the relevant parts are listed below:

struct TORCH_API ClassType : public NamedType {
  // Create a class type with name `name` and its methods stored in `cu`.
  static ClassTypePtr create(
      c10::optional<QualifiedName> qualifiedName,
      std::weak_ptr<CompilationUnit> cu,
      bool is_module = false,
      std::string doc_string = "",
      std::vector<std::string> unresolved_class_attributes = {});

  const std::vector<torch::jit::Function*>& methods() const;
  std::string str() const override;
  bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const override;
  ...
  //—————————————————————— attribute operations ————————————————————————
  TypePtr findAttribute(const std::string& name) const;
  TypePtr getAttribute(const std::string& name) const;
  size_t numAttributes() const;
  size_t addOrCheckAttribute(
      const std::string& name,
      TypePtr ty,
      bool is_parameter = false,
      bool is_buffer = false);
  ...
 //—————————————————————— constant operations ————————————————————————
  size_t addConstant(const std::string& name, const IValue& value);
  IValue getConstant(const std::string& name) const;
  at::ArrayRef<IValue> constantValues() const;
  ...
 //—————————————————————— method / hook operations ————————————————————————
  void addForwardPreHook(torch::jit::Function* pre_hook_ptr);
  void addForwardHook(torch::jit::Function* hook_ptr);
  const std::vector<torch::jit::Function*>& getForwardHooks() const;
  const std::vector<torch::jit::Function*>& getForwardPreHooks() const;
  void addMethod(torch::jit::Function* method);
  torch::jit::Function* findMethod(const std::string& name) const;
  torch::jit::Function& getMethod(const std::string& name) const;
  torch::jit::Function* findHook(const std::string& name) const;
 private:
  std::vector<std::string> constantNames_;
  std::vector<IValue> constantValues_;

  std::vector<ClassAttribute> attributes_;
  std::vector<TypePtr> attributeTypes_;

  std::vector<torch::jit::Function*> methods_;
  std::vector<torch::jit::Function*> staticmethods_;

  std::vector<torch::jit::Function*> forward_hooks_;
  std::vector<torch::jit::Function*> forward_pre_hooks_;
  ...
};

Here, the QualifiedName class represents names of the form foo.bar.baz. The CompilationUnit class can be viewed as a list of named functions: it stores the class's functions and provides interfaces to iterate over and invoke them. You can see that a ClassType behaves much like a real class: besides the common methods inherited from Type, its member functions center on attributes (ClassAttribute), methods (Method, hook, pre-hook), and constant entries, and its construction requires a CompilationUnit.
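
To tie these pieces together, here is a minimal sketch (names such as __torch__.MyModule are illustrative) of building a module-like ClassType by hand, using addAttribute, a simpler sibling of the addOrCheckAttribute shown above:

#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/api/compilation_unit.h>

void build_class_type() {
  // A CompilationUnit owns the named functions (methods) of the class.
  auto cu = std::make_shared<torch::jit::CompilationUnit>();
  auto cls = c10::ClassType::create(
      c10::QualifiedName("__torch__.MyModule"), cu, /*is_module=*/true);

  // One PARAMETER attribute and one REGULAR_ATTRIBUTE:
  cls->addAttribute("weight", c10::TensorType::get(), /*is_parameter=*/true);
  cls->addAttribute("training", c10::BoolType::get());

  // numAttributes() is now 2; getAttribute("weight") returns a TensorType.
}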

c10::ivalue::Object

With ClassType defined to represent class types, we can now create objects of those types. The first class to look at is c10::ivalue::Object, defined in ivalue_inl.h. It is not long, so here is the class definition:

// torch/include/ATen/core/ivalue_inl.h
struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
 public:
  Object(StrongTypePtr type, size_t numSlots) : type_(std::move(type)) {
    slots_.resize(numSlots);
  }

  static c10::intrusive_ptr<Object> create(
      StrongTypePtr type,
      size_t numSlots) {
    return c10::make_intrusive<Object>(std::move(type), numSlots);
  }
  // slot operations: slots_ is just a vector
  void setSlot(size_t slot, IValue v) {
    if (slot >= slots_.size()) {
      resizeObject(slot);
    }
    slots_[slot] = std::move(v);
  }

  const IValue& getSlot(size_t slot) const {
    return slots_[slot];
  }
  void unsafeRemoveSlot(size_t slot);
  // Attributes are also IValues; these give named access to module attributes
  IValue getAttr(const std::string& name) const;
  void setAttr(const std::string& name, IValue v);
  void unsafeRemoveAttr(const std::string& name);

  std::string name() const;

  const std::vector<IValue>& slots() const {
    return slots_;
  }
  std::shared_ptr<ClassType> type() const;
  // shallow / deep copy
  c10::intrusive_ptr<Object> copy() const;
  c10::intrusive_ptr<Object> deepcopy() const;

 private:
  void resizeObject(size_t slot);
  StrongTypePtr type_;
  std::vector<IValue> slots_;
};

// Some of the member functions above are implemented in
// pytorch/pytorch/aten/src/ATen/core/ivalue.cpp; interested readers can
// follow up there.

where StrongTypePtr is:

// torch/include/ATen/core/ivalue.h
struct TORCH_API StrongTypePtr {
  StrongTypePtr(
      std::shared_ptr<torch::jit::CompilationUnit> cu,
      std::shared_ptr<Type> type);

  std::shared_ptr<torch::jit::CompilationUnit> cu_;
  std::shared_ptr<Type> type_;
};

As you can see, the core of ivalue::Object is the slots_ vector, which holds the attributes of a module (that is, of a ClassType): the underlying storage is slots_, and an attribute is just a named wrapper around a slot. When a Module adds parameters or buffers, both ultimately go through addAttribute, with flags marking whether an entry is a parameter, a buffer, or an ordinary attribute.
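
Continuing the hand-built ClassType idea, the following sketch (illustrative; it assumes "weight" was declared as attribute 0 on cls) creates an ivalue::Object and touches the same storage through both the slot view and the attribute view:

#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>

void build_object(std::shared_ptr<torch::jit::CompilationUnit> cu,
                  c10::ClassTypePtr cls) {
  // One slot per attribute declared on the ClassType.
  auto obj = c10::ivalue::Object::create(
      c10::StrongTypePtr(cu, cls), cls->numAttributes());

  // Slot view: a raw index into slots_.
  obj->setSlot(0, c10::IValue(at::ones({3, 3})));

  // Attribute view: the name is resolved to a slot index via the ClassType,
  // so this reads back the tensor stored in slot 0 above.
  at::Tensor w = obj->getAttr("weight").toTensor();
  (void)w;
}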

Next comes the torch::jit::Object class.

torch::jit::Object

// pytorch/pytorch/torch/csrc/jit/api/object.h

using ObjectPtr = c10::intrusive_ptr<c10::ivalue::Object>;

struct TORCH_API Object {
  Object() = default;
  Object(ObjectPtr _ivalue) : _ivalue_(std::move(_ivalue)) {}
  Object(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
  Object(
      c10::QualifiedName,
      std::shared_ptr<CompilationUnit> cu,
      bool shouldMangle = false);
  ObjectPtr _ivalue() const;

  c10::ClassTypePtr type() const {
    return _ivalue()->type();
  }
  
  void setattr(const std::string& name, c10::IValue v) {
    if (_ivalue()->type()->hasConstant(name)) {
      TORCH_CHECK(
          false,
          "Can't set constant '",
          name,
          "' which has value:",
          _ivalue()->type()->getConstant(name));
    } else if (auto slot = _ivalue()->type()->findAttributeSlot(name)) {
      const c10::TypePtr& expected = _ivalue()->type()->getAttribute(*slot);
      TORCH_CHECK(
          v.type()->isSubtypeOf(expected),
          "Expected a value of type '",
          expected->repr_str(),
          "' for field '",
          name,
          "', but found '",
          v.type()->repr_str(),
          "'");
      _ivalue()->setSlot(*slot, std::move(v));
    } else {
      TORCH_CHECK(false, "Module has no attribute '", name, "'");
    }
  }

  c10::IValue attr(const std::string& name) const {
    if (auto r = _ivalue()->type()->findAttributeSlot(name)) {
      return _ivalue()->getSlot(*r);
    }
    if (auto r = _ivalue()->type()->findConstantSlot(name)) {
      return _ivalue()->type()->getConstant(*r);
    }
    std::stringstream err;
    err << _ivalue()->type()->repr_str() << " does not have a field with name '"
        << name.c_str() << "'";
    throw ObjectAttributeError(err.str());
  }

  c10::IValue attr(const std::string& name, c10::IValue or_else) const {
    if (auto r = _ivalue()->type()->findAttributeSlot(name)) {
      return _ivalue()->getSlot(*r);
    }
    if (auto r = _ivalue()->type()->findConstantSlot(name)) {
      return _ivalue()->type()->getConstant(*r);
    }
    return or_else;
  }

  bool hasattr(const std::string& name) const {
    return _ivalue()->type()->hasAttribute(name) ||
        _ivalue()->type()->hasConstant(name);
  }

  // every object has its own methods
  Method get_method(const std::string& name) const {
    if (auto method = find_method(name)) {
      return *method;
    }
    AT_ERROR("Method '", name, "' is not defined.");
  }

  const std::vector<Method> get_methods() const {
    return c10::fmap(type()->methods(), [&](Function* func) {
      return Method(_ivalue(), func);
    });
  }

  c10::optional<Method> find_method(const std::string& basename) const;

  template <typename... Types>
  IValue run_method(const std::string& method_name, Types&&... args) {
    return get_method(method_name)({IValue(std::forward<Types>(args))...});
  }
  
  void define(const std::string& src, const ResolverPtr& resolver = nullptr);

  size_t num_slots() const {
    return _ivalue()->slots().size();
  }
  // copy / deepcopy
  Object copy() const;
  Object deepcopy() const;
  
  private:
  mutable ObjectPtr _ivalue_;
};

The torch::jit::Object class is an extension built on top of c10::ivalue::Object; its copy operations, for example, simply delegate to the c10::ivalue::Object implementations.

Looking at the constructors, a torch::jit::Object is mainly built from the pointers carried by a ClassType: a CompilationUnit and a ClassTypePtr. As the definition above shows, most of its behavior comes from the wrapped c10::ivalue::Object, with the interesting operations going through the ClassType interface reachable via _ivalue()->type(); this is how the basic attr and method operations are performed.

The inheritance chain is ClassType → NamedType → Type. In short, torch::jit::Object combines ivalue::Object with ClassType to provide attribute and Method operations, including shallow and deep copies of the object. It also carries a CompilationUnit pointer (cu), present from construction, which is essentially a list of Functions.
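
As a quick illustration of the attr/setattr pair above, here is a sketch against a module m that is assumed to already have a bool "training" attribute:

#include <torch/csrc/jit/api/module.h>

void poke_attributes(torch::jit::Module& m) {
  // setattr type-checks the value against the slot's declared type
  // via the ClassType, and refuses to overwrite constants.
  m.setattr("training", false);

  // attr(name) looks through attributes first, then constants,
  // and throws ObjectAttributeError if neither exists.
  bool training = m.attr("training").toBool();

  // The two-argument overload returns the fallback instead of throwing.
  c10::IValue maybe = m.attr("not_there", /*or_else=*/c10::IValue());
  (void)training;
  (void)maybe;
}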

torch::jit::Module

First, let's pin down a concept:

A "module" is an abstraction over the implementation of some function or algorithm. A Module has the following kinds of state:

  • Buffers: tensors that do not record gradients and are typically updated during the forward pass, such as the running mean and variance statistics of a BatchNorm operator.
  • Parameters: tensors that do record gradients, such as weights, which are usually updated during the backward pass.
  • Other state: not necessarily tensors, but values needed by the Module's implementation or configuration.

These correspond to AttributeKind in ClassType. Note that cloning a Module performs a deep copy. Module provides register_parameter and register_buffer to register the two kinds of tensors.

Modules can be nested: a Module may contain submodules.

With that background, let's look at the torch::jit::Module class and its operations.

// pytorch/pytorch/torch/csrc/jit/api/module.h
struct TORCH_API Module : public Object {
  // A set of constructors; these mirror the Object constructor shapes
  explicit Module(c10::QualifiedName class_name);
  Module(std::shared_ptr<CompilationUnit> cu, const c10::ClassTypePtr& type);
  Module() = default;
  Module(
      c10::QualifiedName,
      std::shared_ptr<CompilationUnit> cu,
      bool shouldMangle = false);
  Module(ModulePtr module_value) : Object(std::move(module_value)) {}
  ~Module() = default;

  // the main forward entry
  IValue forward(std::vector<IValue> inputs) {
    return get_method("forward")(std::move(inputs));
  }

  // registration of the key Module elements
  void register_buffer(const std::string& name, at::Tensor v) {
    bool is_param = false;
    bool is_buffer = true;
    type()->addOrCheckAttribute(name, TensorType::get(), is_param, is_buffer);
    _ivalue()->setAttr(name, std::move(v));
  }

  void register_parameter(
      const std::string& name,
      at::Tensor v,
      bool is_buffer) {
    type()->addOrCheckAttribute(name, TensorType::get(), !is_buffer, is_buffer);
    _ivalue()->setAttr(name, std::move(v));
  }

  // recursively apply fn to this module and its submodules
  void apply(const std::function<void(Module&)>& fn);
  ...

  // accessors for the actual contents
  buffer_list buffers(bool recurse = true) const;
  named_buffer_list named_buffers(bool recurse = true) const;

  module_list children() const; // direct modules
  named_module_list named_children() const;
  module_list modules() const; // all modules, including this one, recursively
  named_module_list named_modules() const;

  // all tensors involved in gradient optimization
  parameter_list parameters(bool recurse = true) const;
  named_parameter_list named_parameters(bool recurse = true) const;

  // all members of the object, similar to iterating over dir(obj) in python
  attribute_list attributes(bool recurse = true) const;
  named_attribute_list named_attributes(bool recurse = true) const;

  // dump
  void dump(
      bool print_method_bodies,
      bool print_attr_values,
      bool print_param_values) const;
   ...

  // There are also to(), save(), copy(), deepcopy(), clone(),
  // clone_method(), etc., which we will not expand on here.
};
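
Before moving on to tracing, here is a minimal sketch of driving these registration interfaces directly from C++ (the class names are illustrative; register_attribute and register_module exist in module.h but fall under the "..." elisions above):

#include <ATen/ATen.h>
#include <torch/csrc/jit/api/module.h>

void build_modules() {
  torch::jit::Module child(c10::QualifiedName("__torch__.Child"));
  child.register_parameter("weight", at::randn({3, 3}), /*is_buffer=*/false);

  torch::jit::Module parent(c10::QualifiedName("__torch__.Parent"));
  parent.register_buffer("running_mean", at::zeros({3}));
  // Modules nest: the child becomes a submodule attribute of the parent.
  parent.register_module("child", child);
}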

When we pass an nn.Module to torch.jit.trace, the Python side first converts it into a script module. The model used here is defined as:

class TestArangeModel(torch.nn.Module):
    def __init__(self):
        super(TestArangeModel, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 3, 3, stride=(1, 1), padding=(1, 1), bias=False)
        self.conv2 = torch.nn.Conv2d(3, 3, 3, stride=(1, 1), padding=(1, 1))

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

m = TestArangeModel()
input = torch.randn(1, 3, 5, 5)
traced = torch.jit.trace(m, input)

Dumping the incoming jit::Module with the dump interface produces roughly the following:

module __torch__.TestArangeModel {
  parameters {
  }
  attributes {
    training = True
    _is_full_backward_hook = None
    conv1 = <__torch__.torch.nn.modules.conv.Conv2d object at 0x557569c9f300>
    conv2 = <__torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d object at 0x557569cb4080>
  }
  methods {
  }
  submodules {
    module __torch__.torch.nn.modules.conv.Conv2d {
      parameters {
        weight = ...
      }
      attributes {
        weight = ...
        training = True
        _is_full_backward_hook = None
      }
      methods {
      }
      submodules {
      }
    } // the QualifiedName below is: __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d
    module __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d {
      parameters {
        weight = ...  // parameters and buffers both live under parameters; flags distinguish them underneath
        bias = ...
      }
      attributes {
        weight = ... // when registered, parameters and buffers are also added to attributes, alongside the other attributes
        bias = ...
        training = True
        _is_full_backward_hook = None
      }
      methods {   // methods
      }
      submodules { // modules themselves can be nested
      }
    }
  }
}

torch.jit.trace is essentially the process of converting a torch::jit::Module into a torch::jit::Graph. The heart of it is the series of transformations applied to the graph after it is built, with a class named TracingState tracking the intermediate state; the result is the converted torch::jit::Graph.
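
As a side note, the end product is easy to inspect from C++ as well. Here is a sketch (file name illustrative) that loads a module saved via traced.save(...) and prints the forward graph this section dissects:

#include <torch/script.h>
#include <iostream>

int main() {
  torch::jit::Module m = torch::jit::load("traced.pt");
  // get_method comes from torch::jit::Object; Method::graph() exposes
  // the torch::jit::Graph produced by tracing.
  std::shared_ptr<torch::jit::Graph> g = m.get_method("forward").graph();
  std::cout << g->toString() << "\n";
  return 0;
}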

Conversion Process

The classes involved have been introduced above. With that in hand, this section outlines roughly how the jit.trace interface works on the C++ side.

As seen earlier, the initial Module has no methods (likely because it was converted from an nn.Module, whose class definition carries no method content of this kind). Next we explore, step by step, how the graph is built, and what the Module ends up looking like.

  1. First, all attributes of the module object, together with its inputs and outputs, are recursively represented as nodes:
graph(%self : __torch__.___torch_mangle_3.TestArangeModel,
      %x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
  %1 : bool = prim::TracedAttr[scope="__module.training"]()
  %2 : NoneType = prim::TracedAttr[scope="__module._is_full_backward_hook"]()
  %3 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d = prim::TracedAttr[scope="__module.conv1"]()
  %4 : Float(3, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv1.weight"]()
  %5 : bool = prim::TracedAttr[scope="__module.conv1.training"]()
  %6 : NoneType = prim::TracedAttr[scope="__module.conv1._is_full_backward_hook"]()
  %7 : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d = prim::TracedAttr[scope="__module.conv2"]()
  %8 : Float(3, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv2.weight"]()
  %9 : Float(3, strides=[1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv2.bias"]()
  %10 : bool = prim::TracedAttr[scope="__module.conv2.training"]()
  %11 : NoneType = prim::TracedAttr[scope="__module.conv2._is_full_backward_hook"]()
  return ()

This is what the Module looks like right after being converted into a Graph.

After the module's forward function has been executed under tracing, we get:

graph(%self : __torch__.TestArangeModel,
      %x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
  %1 : bool = prim::TracedAttr[scope="__module.training"]()
  %2 : NoneType = prim::TracedAttr[scope="__module._is_full_backward_hook"]()
  %3 : __torch__.torch.nn.modules.conv.Conv2d = prim::TracedAttr[scope="__module.conv1"]()
  %weight.1 : Float(3, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv1.weight"]()
  %5 : bool = prim::TracedAttr[scope="__module.conv1.training"]()
  %6 : NoneType = prim::TracedAttr[scope="__module.conv1._is_full_backward_hook"]()
  %7 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d = prim::TracedAttr[scope="__module.conv2"]()
  %weight : Float(3, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv2.weight"]()
  %bias : Float(3, strides=[1], requires_grad=1, device=cpu) = prim::TracedAttr[scope="__module.conv2.bias"]()
  %10 : bool = prim::TracedAttr[scope="__module.conv2.training"]()
  %11 : NoneType = prim::TracedAttr[scope="__module.conv2._is_full_backward_hook"]()
   = prim::TracedModuleForward[scope="__module.conv1"](), scope: __module.conv1
    block0():
      %13 : NoneType = prim::Constant(), scope: __module.conv1
      %14 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %15 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %16 : int[] = prim::ListConstruct(%14, %15), scope: __module.conv1
      %17 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %18 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %19 : int[] = prim::ListConstruct(%17, %18), scope: __module.conv1
      %20 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %21 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %22 : int[] = prim::ListConstruct(%20, %21), scope: __module.conv1
      %23 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %24 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %25 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %26 : int[] = prim::ListConstruct(%24, %25), scope: __module.conv1
      %27 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %28 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %29 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %30 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %31 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %input : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=1, device=cpu) = aten::_convolution(%x, %weight.1, %13, %16, %19, %22, %23, %26, %27, %28, %29, %30, %31), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      -> ()
   = prim::TracedModuleForward[scope="__module.conv2"](), scope: __module.conv2
    block0():
      %33 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %34 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %35 : int[] = prim::ListConstruct(%33, %34), scope: __module.conv2
      %36 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %37 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %38 : int[] = prim::ListConstruct(%36, %37), scope: __module.conv2
      %39 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %40 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %41 : int[] = prim::ListConstruct(%39, %40), scope: __module.conv2
      %42 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %43 : int = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %44 : int = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %45 : int[] = prim::ListConstruct(%43, %44), scope: __module.conv2
      %46 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %47 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %48 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %49 : bool = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %50 : bool = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      %51 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=1, device=cpu) = aten::_convolution(%input, %weight, %bias, %35, %38, %41, %42, %45, %46, %47, %48, %49, %50), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
      -> ()
  return (%51)

Next, the graph goes through various passes for further processing, entering the optimization passes in pytorch/pytorch/torch/csrc/jit/passes/fixup_trace_scope_blocks.cpp:

void FixupTraceScopeBlocks(std::shared_ptr<Graph>& graph, Module* self) {
  if (self) {
    ConvertTracedAttrReferences().run(graph);
  } else {
    for (Node* n : graph->nodes()) {
      TORCH_INTERNAL_ASSERT(n->kind() != prim::TracedAttr);
    }
  }
  MakeDefsDominateUses().run(graph->block());
  convertReturnsToTuples(graph->block());
  if (!self) {
    // We have no Module, so we're just going to inline everything.
    // This should give us a totally flat graph.
    inlineScopeBlocks(graph->block());
    // For TracedFork nodes
    lambdaLiftBlocksAndConvertToGraph(graph->block());
    runCleanupPasses(graph);
  } else {
    lambdaLiftBlocksAndConvertToGraph(graph->block());
    createMethodCalls(graph);
    runCleanupPasses(self);
    // `graph` isn't referenced in `self` yet, so we need to run
    // this separately
    runCleanupPasses(graph);
  }
}

The main idea of this function is to sort out the relationships between the nodes in the graph and then handle the prim::Traced* nodes. Roughly, it takes the following steps:

  1. Recursively find the prim::TracedModuleForward nodes; for each, find the prim::TracedAttr with the same scope, fetch the corresponding submodule node from the input module, pass it into the prim::TracedModuleForward node as an argument, and add the input information to block0.

  2. Insert the necessary prim::GetAttr nodes to replace the prim::TracedAttr nodes that are actually used, and rewire the corresponding outputs.

  3. Delete the unused prim::TracedAttr nodes.

  4. For TracedModuleForward nodes written in the bare `=` form (no outputs), attach input/output names, types, and other information.

  5. Pack multiple outputs into a tuple.

  6. Build a new subgraph from each TracedModuleForward block, then store that subgraph as the node's Subgraph attribute:

    This roughly transforms the graph from:

%59 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward[scope="__module.conv1"](%conv1), scope: __module.conv1
  block0(%self.3 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d):
    %weight.5 : Tensor = prim::GetAttr[name="weight"](%self.3)
    %13 : NoneType = prim::Constant(), scope: __module.conv1
    %14 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %15 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %16 : int[] = prim::ListConstruct(%14, %15), scope: __module.conv1
    %17 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %18 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %19 : int[] = prim::ListConstruct(%17, %18), scope: __module.conv1
    %20 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %21 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %22 : int[] = prim::ListConstruct(%20, %21), scope: __module.conv1
    %23 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %24 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %25 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %26 : int[] = prim::ListConstruct(%24, %25), scope: __module.conv1
    %27 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %28 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %29 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %30 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %31 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %input : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = aten::_convolution(%x, %weight.5, %13, %16, %19, %22, %23, %26, %27, %28, %29, %30, %31), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    -> (%input)

into:

%59 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward[scope="__module.conv1", Subgraph=](%conv1, %x), scope: __module.conv1
  block0(%self.3 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d):
    %weight.5 : Tensor = prim::GetAttr[name="weight"](%self.3)
    %13 : NoneType = prim::Constant(), scope: __module.conv1
    %14 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %15 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %16 : int[] = prim::ListConstruct(%14, %15), scope: __module.conv1
    %17 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %18 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %19 : int[] = prim::ListConstruct(%17, %18), scope: __module.conv1
    %20 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %21 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %22 : int[] = prim::ListConstruct(%20, %21), scope: __module.conv1
    %23 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %24 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %25 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %26 : int[] = prim::ListConstruct(%24, %25), scope: __module.conv1
    %27 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %28 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %29 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %30 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %31 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    %input : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = aten::_convolution(%x, %weight.5, %13, %16, %19, %22, %23, %26, %27, %28, %29, %30, %31), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
    -> (%input)

The block is now attached to the prim::TracedModuleForward node as its Subgraph attribute, so the block itself can be deleted:

%59 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward[scope="__module.conv1", Subgraph=](%conv1, %x), scope: __module.conv1

Finally, each prim::TracedModuleForward is converted into a prim::CallMethod node, and n->output()->replaceAllUsesWith(retval); rewires the old node's uses to the new one:

%61 : Tensor = prim::CallMethod[name="forward"](%conv1, %x)

The graph then changes to:

graph(%self.1 : __torch__.___torch_mangle_3.TestArangeModel,
      %x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
  %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d = prim::GetAttr[name="conv2"](%self.1)
  %conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %61 : Tensor = prim::CallMethod[name="forward"](%conv1, %x)
  %59 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward_0[scope="__module.conv1"](%conv1, %x), scope: __module.conv1
  %60 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward_1[scope="__module.conv2"](%conv2, %61), scope: __module.conv2
  return (%60)

Finally, the unused nodes are removed:

graph(%self.1 : __torch__.___torch_mangle_3.TestArangeModel,
      %x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
  %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d = prim::GetAttr[name="conv2"](%self.1)
  %conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %61 : Tensor = prim::CallMethod[name="forward"](%conv1, %x)
  %60 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = prim::TracedModuleForward_0[scope="__module.conv2"](%conv2, %61), scope: __module.conv2
  return (%60)

To sum up, the optimization passes mainly do the following:

  • Replace prim::TracedAttr with prim::GetAttr, and delete the redundant prim::TracedAttr nodes (redundant meaning they have no users).
  • Process the prim::TracedModuleForward nodes: first sort out the input/output information, then turn each block into the node's Subgraph attribute, and finally replace the prim::TracedModuleForward node with a prim::CallMethod node.

When a prim::TracedModuleForward node is replaced with a prim::CallMethod node, the graph built from its block (the node's Subgraph attribute) is also inserted into the submodule's methods, so the incoming Module is being mutated along the way; the sketch below illustrates the replacement idiom.
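
The replacement idiom itself looks roughly like the following (illustrative, not the literal pass code; the real createMethodCalls also moves the block's graph into a submodule method and threads the remaining inputs through):

#include <torch/csrc/jit/ir/ir.h>

// Replace a traced-forward style node, whose inputs are (object, argument),
// with a prim::CallMethod node that calls "forward" on the object.
void replace_with_call_method(torch::jit::Node* n) {
  torch::jit::Graph* g = n->owningGraph();
  torch::jit::WithInsertPoint guard(n); // insert right before the old node

  torch::jit::Node* call =
      g->create(c10::prim::CallMethod, {n->input(0), n->input(1)});
  call->s_(c10::attr::name, "forward"); // the method name is an attribute
  g->insertNode(call);

  // Rewire every use of the old output to the new one, then drop the node.
  n->output()->replaceAllUsesWith(call->output());
  n->destroy();
}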

Repeating the steps above, the graph we finally trace out looks like this:

graph(%self.1 : __torch__.___torch_mangle_3.TestArangeModel,
      %x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
  %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d = prim::GetAttr[name="conv2"](%self.1)
  %conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %61 : Tensor = prim::CallMethod[name="forward"](%conv1, %x)
  %62 : Tensor = prim::CallMethod[name="forward"](%conv2, %61)
  return (%62)

Finally, let's look at the module after conversion:

module __torch__.___torch_mangle_3.TestArangeModel {
  parameters {
  }
  attributes {
    training = True
    _is_full_backward_hook = None
    conv1 = <__torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d object at 0x559c691cd470>
    conv2 = <__torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d object at 0x559c691cdcb0>
  }
  methods {
    method forward {
      graph(%self.1 : __torch__.___torch_mangle_3.TestArangeModel,
            %x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
        %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d = prim::GetAttr[name="conv2"](%self.1)
        %conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d = prim::GetAttr[name="conv1"](%self.1)
        %61 : Tensor = prim::CallMethod[name="forward"](%conv1, %x)
        %62 : Tensor = prim::CallMethod[name="forward"](%conv2, %61)
        return (%62)
  
    }
  }
  submodules {
    module __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d {
      parameters {
        weight = ...
      }
      attributes {
        weight = ...
        training = True
        _is_full_backward_hook = None
      }
      methods {
        method forward {
          graph(%self.3 : __torch__.torch.nn.modules.conv.___torch_mangle_1.Conv2d,
                %x : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
            %weight.5 : Tensor = prim::GetAttr[name="weight"](%self.3)
            %2 : NoneType = prim::Constant(), scope: __module.conv1
            %3 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %4 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %5 : int[] = prim::ListConstruct(%3, %4), scope: __module.conv1
            %6 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %7 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %8 : int[] = prim::ListConstruct(%6, %7), scope: __module.conv1
            %9 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %10 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %11 : int[] = prim::ListConstruct(%9, %10), scope: __module.conv1
            %12 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %13 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %14 : int = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %15 : int[] = prim::ListConstruct(%13, %14), scope: __module.conv1
            %16 : int = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %17 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %18 : bool = prim::Constant[value=0](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %19 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %20 : bool = prim::Constant[value=1](), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %input : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = aten::_convolution(%x, %weight.5, %2, %5, %8, %11, %12, %15, %16, %17, %18, %19, %20), scope: __module.conv1 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            return (%input)
      
        }
      }
      submodules {
      }
    }
    module __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d {
      parameters {
        weight = ...
        bias = ...
      }
      attributes {
        weight = ...
        bias = ...
        training = True
        _is_full_backward_hook = None
      }
      methods {
        method forward {
          graph(%self : __torch__.torch.nn.modules.conv.___torch_mangle_2.Conv2d,
                %22 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu)):
            %bias : Tensor = prim::GetAttr[name="bias"](%self)
            %weight : Tensor = prim::GetAttr[name="weight"](%self)
            %3 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %4 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %5 : int[] = prim::ListConstruct(%3, %4), scope: __module.conv2
            %6 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %7 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %8 : int[] = prim::ListConstruct(%6, %7), scope: __module.conv2
            %9 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %10 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %11 : int[] = prim::ListConstruct(%9, %10), scope: __module.conv2
            %12 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %13 : int = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %14 : int = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %15 : int[] = prim::ListConstruct(%13, %14), scope: __module.conv2
            %16 : int = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %17 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %18 : bool = prim::Constant[value=0](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %19 : bool = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %20 : bool = prim::Constant[value=1](), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            %21 : Float(1, 3, 5, 5, strides=[75, 25, 5, 1], requires_grad=0, device=cpu) = aten::_convolution(%22, %weight, %bias, %5, %8, %11, %12, %15, %16, %17, %18, %19, %20), scope: __module.conv2 # /home2/anaconda3/lib/python3.7/site-packages/torch/nn/modules/conv.py:443:0
            return (%21)
        }
      }
      submodules {
      }
    }
  }
}

Summary

torch.jit.trace is, at heart, a process that converts a module into a JIT graph. This post walked through that process:

  • Recursively process the Module's attributes and subgraphs to generate the Graph's nodes.
  • Then transform, process, and optimize those intermediate nodes, finally producing a torch.jit.Graph that we can print out.

The bulk of the work consists of operations on the nodes of the Graph. If you need more detail, you can follow the thread laid out above and read further into the source on your own.
