断断续续接触这两个概念有很长一段时间了,但是始终觉得对这两个 pytorch 的重要特性的概念就是比较模糊,中间还夹杂了一个 JIT trace 的概念,让我一句话归纳总结它们就是:
PyTorch 深受人们的喜爱主要是因为它的灵活和易用性(畏难心理,我到现在都还是对 TF 有点排斥),但是在模型部署方面,PyTorch 的表现却不尽人意,性能及可移植性都欠缺。之前使用 PyTorch 的痛点也是从研究到产品跨度比较大,不能直接将模型用来部署,为了解决这个 gap,PyTorch 提出了 TorchScript,想要通过它来实现从研究到产品的框架统一,通过TorchScript得到的模型可以脱离 python 的 runtime 并使你的模型跑的更快。
TorchScript 是一种编程语言,是 Python 的静态类型子集,它有自己的语法规则,我们使用 eager 模式来进行原型验证及训练的过程都是直接使用 python 语法,所以想得到方便部署的 script mode 需要通过torch.jit.trace 或者是 torch.jit.script 去处理模型。
torch.jit.trace() 把训练后得到的 eager 模型以及模型需要的输入数据作为接口输入,然后 tracer 会把数据在 eager 模型里运行一次,并且记录执行的 tensor 操作,记录的结果会保存成一个 TorchScript 模块。
但是它的主要缺点就是不支持控制流,数据结构(list,dict 等)和 python 结构,并且可能部分操作没有正确的被记录在 TorchScript 模块中,但是不会给任何警示信息,不能保证输出的一定是正确的 TorchScript 模块。
torch.jit.script 用作装饰器可以将你的代码转化成写成 TorchScript 语言,它转化出来的模型更冗长(携带更多的信息),但是更通用,经过些许修改就可以支持大部分的 PyTorch 模型。 也可以用作接口,直接将 eager 模型送入torch.jit.script(),无需再送入数据。它支持控制流以及一些 Python 的数据结构。但是它会省略常量节点,并需要类型转换,如果没有类型提供则默认是 Tensor 类型。
因为 torch.jit.trace() 不支持控制流,torch.jit.script() 不会记录常量节点,当我们需要记录常量节点又需要支持控制流时就可以把二者结合在一起,下面直接贴出官方示例:
class MyRNNLoop(torch.nn.Module):
def __init__(self):
super(MyRNNLoop, self).__init__()
self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))
def forward(self, xs):
h, y = torch.zeros(3, 4), torch.zeros(3, 4)
for i in range(xs.size(0)):
y, h = self.cell(xs[i], h)
return y, h
rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
可以得到:
def forward(self,
xs: Tensor) -> Tuple[Tensor, Tensor]:
h = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
y = torch.zeros([3, 4], dtype=None, layout=None, device=None, pin_memory=None)
y0 = y
h0 = h
for i in range(torch.size(xs, 0)):
_0 = (self.cell).forward(torch.select(xs, 0, i), h0, )
y1, h1, = _0
y0, h0 = y1, h1
return (y0, h0)
或者:
class WrapRNN(torch.nn.Module):
def __init__(self):
super(WrapRNN, self).__init__()
self.loop = torch.jit.script(MyRNNLoop())
def forward(self, xs):
y, h = self.loop(xs)
return torch.relu(y)
traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
可以得到:
def forward(self,
argument_1: Tensor) -> Tensor:
_0, h, = (self.loop).forward(argument_1, )
return torch.relu(h)
参考文章: