PyTorch Eager mode and Script mode

文章目录

        • 前言
        • Script model
          • torch.jit.trace
          • torch.jit.script

本文大概总结一下近期对 pytorch 中的 eager 模式还有 script 模式的学习所得。

前言

断断续续接触这两个概念有很长一段时间了,但是始终觉得对这两个 pytorch 的重要特性的概念就是比较模糊,中间还夹杂了一个 JIT trace 的概念,让我一句话归纳总结它们就是:

  • Eager 模式:Python + Python runtime。这种模式是更 Pythonic 的编程模式,可以让用户很方便的使用 python 的语法来使用并调试框架,就像我们刚认识 pytorch 时它的样子,自带 eager 属性。(但是我始终对这个 eager 有点对不上号 T _ T)
  • Script 模式:TorchScript + PyTorch JIT。这种模式会对 eager 模式的模型创建一个中间表示(intermediate representation,IR),这个 IR 经过内部优化的,并且可以使用 PyTorch JIT 编译器去运行模型,不再依赖 python runtime,也可以使用 C++ 加载运行。

Script model

PyTorch 深受人们的喜爱主要是因为它的灵活和易用性(畏难心理,我到现在都还是对 TF 有点排斥),但是在模型部署方面,PyTorch 的表现却不尽人意,性能及可移植性都欠缺。之前使用 PyTorch 的痛点也是从研究到产品跨度比较大,不能直接将模型用来部署,为了解决这个 gap,PyTorch 提出了 TorchScript,想要通过它来实现从研究到产品的框架统一,通过TorchScript得到的模型可以脱离 python 的 runtime 并使你的模型跑的更快。

  • 可移植性:script 模式可以不用再使用 python runtime,因此可以用在多线程推理服务器,移动设备,自动驾驶等 python 很难应用的场景。
  • 性能表现:PyTorch JIT 是可以对 PyTorch 模型做特定优化的 JIT 编译器,其可以利用 runtime 的信息做量化,层融合,稀疏化等 Script 模型优化加速模型。

TorchScript 是一种编程语言,是 Python 的静态类型子集,它有自己的语法规则,我们使用 eager 模式来进行原型验证及训练的过程都是直接使用 python 语法,所以想得到方便部署的 script mode 需要通过torch.jit.trace 或者是 torch.jit.script 去处理模型。

torch.jit.trace

torch.jit.trace() 把训练后得到的 eager 模型以及模型需要的输入数据作为接口输入,然后 tracer 会把数据在 eager 模型里运行一次,并且记录执行的 tensor 操作,记录的结果会保存成一个 TorchScript 模块。

但是它的主要缺点就是不支持控制流,数据结构(list,dict 等)和 python 结构,并且可能部分操作没有正确的被记录在 TorchScript 模块中,但是不会给任何警示信息,不能保证输出的一定是正确的 TorchScript 模块。

torch.jit.script

torch.jit.script 用作装饰器可以将你的代码转化成写成 TorchScript 语言,它转化出来的模型更冗长(携带更多的信息),但是更通用,经过些许修改就可以支持大部分的 PyTorch 模型。 也可以用作接口,直接将 eager 模型送入torch.jit.script(),无需再送入数据。它支持控制流以及一些 Python 的数据结构。但是它会省略常量节点,并需要类型转换,如果没有类型提供则默认是 Tensor 类型。

因为 torch.jit.trace() 不支持控制流,torch.jit.script() 不会记录常量节点,当我们需要记录常量节点又需要支持控制流时就可以把二者结合在一起,下面直接贴出官方示例:

class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)

可以得到:

def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  y = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    _0 = (self.cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)

或者:

class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)

可以得到:

def forward(self,
    argument_1: Tensor) -> Tensor:
  _0, h, = (self.loop).forward(argument_1, )
  return torch.relu(h)

参考文章:

  • https://towardsdatascience.com/pytorch-jit-and-torchscript-c2a77bac0fff
  • https://pytorch.org/docs/stable/jit_language_reference.html#language-reference
  • https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html
  • https://pytorch.org/docs/stable/jit.html

你可能感兴趣的:(python,深度学习)