TorchScript 系列解读(一):初识 TorchScript

大家好呀,不久前我们推出了模型部署入门系列教程,受到了大家的一致好评,也收到了很多小伙伴的催更,后续教程正在准备中,将在不久后跟大家见面,敬请期待哦~

OpenMMLab:模型部署系列教程(一):模型部署简介78 赞同 · 14 评论文章正在上传…重新上传取消

OpenMMLab:模型部署入门系列教程(二):解决模型部署中的难题42 赞同 · 12 评论文章正在上传…重新上传取消

今天,我们又将开启新的 TorchScript 解读系列教程,带领大家玩转 PyTorch 模型部署。

什么是 TorchScript?

PyTorch 无疑是现在最成功的深度学习训练框架之一,是各种顶会顶刊论文实验的大热门。比起其他的框架,PyTorch 最大的卖点是它对动态网络的支持,比其他需要构建静态网络的框架拥有更低的学习成本。PyTorch 源码 Readme 中还专门为此做了一张动态图:

TorchScript 系列解读(一):初识 TorchScript_第1张图片


对研究员而言, PyTorch 能极大地提高想 idea、做实验、发论文的效率,是训练框架中的豪杰,但是它不适合部署。动态建图带来的优势对于性能要求更高的应用场景而言更像是缺点,非固定的网络结构给网络结构分析并进行优化带来了困难,多数参数都能以 Tensor 形式传输也让资源分配变成一件闹心的事。另外由于图是由 python 代码构建的,一方面部署要依赖 python 环境,另一方面模型也毫无保密性可言。

而 TorchScript 就是为了解决这个问题而诞生的工具。包括代码的追踪及解析、中间表示的生成、模型优化、序列化等各种功能,可以说是覆盖了模型部署的方方面面。今天我们先简要地介绍一些 TorchScript 的功能,让大家有一个初步的认识,进阶的解读会陆续推出~

模型转换

作为模型部署的一个范式,通常我们都需要生成一个模型的中间表示(IR),这个 IR 拥有相对固定的图结构,所以更容易优化,让我们看一个例子:

import torch 
from torchvision.models import resnet18 
 
# 使用PyTorch model zoo中的resnet18作为例子 
model = resnet18() 
model.eval() 
 
# 通过trace的方法生成IR需要一个输入样例 
dummy_input = torch.rand(1, 3, 224, 224) 
 
# IR生成 
with torch.no_grad(): 
    jit_model = torch.jit.trace(model, dummy_input) 

到这里就将 PyTorch 的模型转换成了 TorchScript 的 IR。这里我们使用了 trace 模式来生成 IR,所谓 trace 指的是进行一次模型推理,在推理的过程中记录所有经过的计算,将这些记录整合成计算图。关于 trace 的过程我们会在未来的分享中进行解读。

那么这个 IR 中到底都有些什么呢?我们可以可视化一下其中的 layer1 看看:

jit_layer1 = jit_model.layer1 
print(jit_layer1.graph) 
 
# graph(%self.6 : __torch__.torch.nn.modules.container.Sequential, 
#       %4 : Float(1, 64, 56, 56, strides=[200704, 3136, 56, 1], requires_grad=0, device=cpu)): 
#   %1 : __torch__.torchvision.models.resnet.___torch_mangle_10.BasicBlock = prim::GetAttr[name="1"](%self.6) 
#   %2 : __torch__.torchvision.models.resnet.BasicBlock = prim::GetAttr[name="0"](%self.6) 
#   %6 : Tensor = prim::CallMethod[name="forward"](%2, %4) 
#   %7 : Tensor = prim::CallMethod[name="forward"](%1, %6) 
#   return (%7) 

是不是有点摸不着头脑?TorchScript 有它自己对于 Graph 以及其中元素的定义,对于第一次接触的人来说可能比较陌生,但是没关系,我们还有另一种可视化方式:

print(jit_layer1.code) 
 
# def forward(self, 
#     argument_1: Tensor) -> Tensor: 
#   _0 = getattr(self, "1") 
#   _1 = (getattr(self, "0")).forward(argument_1, ) 
#   return (_0).forward(_1, ) 

没错,就是代码!TorchScript 的 IR 是可以还原成 python 代码的,如果你生成了一个 TorchScript 模型并且想知道它的内容对不对,那么可以通过这样的方式来做一些简单的检查。

刚才的例子中我们使用 trace 的方法生成IR。除了 trace 之外,PyTorch 还提供了另一种生成 TorchScript 模型的方法:script。这种方式会直接解析网络定义的 python 代码,生成抽象语法树 AST,因此这种方法可以解决一些 trace 无法解决的问题,比如对 branch/loop 等数据流控制语句的建图。script方式的建图有很多有趣的特性,会在未来的分享中做专题分析,敬请期待。

模型优化

聪明的同学可能发现了,上面的可视化中只有resnet18forward的部分,其中的子模块信息是不是丢失了呢?如果没有丢失,那么怎么样才能确定子模块的内容是否正确呢?别担心,还记得我们说过 TorchScript 支持对网络的优化吗,这里我们就可以用一个pass解决这个问题:

# 调用inline pass,对graph做变换 
torch._C._jit_pass_inline(jit_layer1.graph) 
print(jit_layer1.code) 
 
# def forward(self, 
#     argument_1: Tensor) -> Tensor: 
#   _0 = getattr(self, "1") 
#   _1 = getattr(self, "0") 
#   _2 = _1.bn2 
#   _3 = _1.conv2 
#   _4 = _1.bn1 
#   input = torch._convolution(argument_1, _1.conv1.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True) 
#   _5 = _4.running_var 
#   _6 = _4.running_mean 
#   _7 = _4.bias 
#   input0 = torch.batch_norm(input, _4.weight, _7, _6, _5, False, 0.10000000000000001, 1.0000000000000001e-05, True) 
#   input1 = torch.relu_(input0) 
#   input2 = torch._convolution(input1, _3.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True) 
#   _8 = _2.running_var 
#   _9 = _2.running_mean 
#   _10 = _2.bias 
#   out = torch.batch_norm(input2, _2.weight, _10, _9, _8, False, 0.10000000000000001, 1.0000000000000001e-05, True) 
#   input3 = torch.add_(out, argument_1, alpha=1) 
#   input4 = torch.relu_(input3) 
#   _11 = _0.bn2 
#   _12 = _0.conv2 
#   _13 = _0.bn1 
#   input5 = torch._convolution(input4, _0.conv1.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True) 
#   _14 = _13.running_var 
#   _15 = _13.running_mean 
#   _16 = _13.bias 
#   input6 = torch.batch_norm(input5, _13.weight, _16, _15, _14, False, 0.10000000000000001, 1.0000000000000001e-05, True) 
#   input7 = torch.relu_(input6) 
#   input8 = torch._convolution(input7, _12.weight, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True) 
#   _17 = _11.running_var 
#   _18 = _11.running_mean 
#   _19 = _11.bias 
#   out0 = torch.batch_norm(input8, _11.weight, _19, _18, _17, False, 0.10000000000000001, 1.0000000000000001e-05, True) 
#   input9 = torch.add_(out0, input4, alpha=1) 
#   return torch.relu_(input9) 

这里我们就能看到卷积、batch_norm、relu等熟悉的算子了。

上面代码中我们使用了一个名为inlinepass,将所有子模块进行内联,这样我们就能看见更完整的推理代码。pass是一个来源于编译原理的概念,一个 TorchScript 的 pass 会接收一个图,遍历图中所有元素进行某种变换,生成一个新的图。我们这里用到的inline起到的作用就是将模块调用展开,尽管这样做并不能直接影响执行效率,但是它其实是很多其他pass的基础。PyTorch 中定义了非常多的 pass 来解决各种优化任务,未来我们会做一些更详细的介绍。

序列化

不管是哪种方法创建的 TorchScript 都可以进行序列化,比如:

# 将模型序列化 
jit_model.save('jit_model.pth') 
# 加载序列化后的模型 
jit_model = torch.jit.load('jit_model.pth') 

序列化后的模型不再与 python 相关,可以被部署到各种平台上。

PyTorch 提供了可以用于 TorchScript 模型推理的 c++ API,序列化后的模型终于可以不依赖 python 进行推理了:

// 加载生成的torchscript模型 
auto module = torch::jit::load('jit_model.pth'); 
// 根据任务需求读取数据 
std::vector inputs = ...; 
// 计算推理结果 
auto output = module.forward(inputs).toTensor(); 

与其他组件的关系?

与 torch.onnx 的关系?

TorchScript 系列解读(一):初识 TorchScript_第2张图片

TorchScript 系列解读(一):初识 TorchScript_第3张图片

ONNX 是业界广泛使用的一种神经网络中间表示,PyTorch 自然也对 ONNX 提供了支持。torch.onnx.export函数可以帮助我们把 PyTorch 模型转换成 ONNX 模型,这个函数会使用 trace 的方式记录 PyTorch 的推理过程。聪明的同学可能已经想到了,没错,ONNX 的导出,使用的正是 TorchScript 的 trace 工具。具体步骤如下:

  1. 使用 trace 的方式先生成一个 TorchScipt 模型,如果你转换的本身就是 TorchScript 模型,则可以跳过这一步。
  2. 使用许多 pass 对 1 中生成的模型进行变换,其中对 ONNX 导出最重要的一个 pass 就是ToONNX,这个 pass 会进行一个映射,将 TorchScript 中primaten空间下的算子映射到onnx空间下的算子。
  3. 使用 ONNX 的 proto 格式对模型进行序列化,完成 ONNX 的导出。

关于 ONNX 导出的实现以及算子映射的方式将会在未来的分享中详细展开。

与 torch.fx 的关系?

PyTorch1.9 开始添加了torch.fx工具,根据官方的介绍,它由符号追踪器(symbolic tracer),中间表示(IR), Python 代码生成(Python code generation)等组件组成,实现了python->python的翻译。是不是和 TorchScript 看起来有点像?

其实他们之间联系不大,可以算是互相垂直的两个工具,为解决两个不同的任务而诞生。

  • TorchScript 的主要用途是进行模型部署,需要记录生成一个便于推理优化的 IR,对计算图的编辑通常都是面向性能提升等等,不会给模型本身添加新的功能。
  • FX 的主要用途是进行python->python的翻译,它的 IR 中节点类型更简单,比如函数调用属性提取等等,这样的 IR 学习成本更低更容易编辑。使用 FX 来编辑图通常是为了实现某种特定功能,比如给模型插入量化节点等,避免手动编辑网络造成的重复劳动。

这两个工具可以同时使用,比如使用 FX 工具编辑模型来让训练更便利、功能更强大;然后用 TorchScript 将模型加速部署到特定平台。

希望通过以上的分享,大家对 TorchScript 有了一个初步的认识,未来我们将会为大家带来更进阶的解读,欢迎大家持续关注。另外值得分享的是,MMDeploy 已开始对 TorchScript 提供支持,欢迎大家访问 MMDeploy GitHub 主页 体验哦。

今天的分享就先到这里结束啦,你学会了吗?如果我们的分享给你带来一定的帮助,欢迎点赞收藏关注,比心~

你可能感兴趣的:(技术干货,pytorch,python,人工智能,深度学习)