之前写到用ubuntu上训练pytorch的网络,再window c++调用的文章。实际调用过程中,并不是简单load就可以,需要将PyTorch模型转换为Torch Script,在这里记录一下,供有需要的盆友参考。
PyTorch模型从Python到C++的转换由Torch Script实现。Torch Script是PyTorch模型的一种表示,可由Torch Script编译器理解,编译和序列化。
将PyTorch模型转换为Torch Script有两种方法。
- 第一种方法是Tracing。该方法通过将样本输入到模型中一次来对该过程进行评估从而捕获模型结构.并记录该样本在模型中的flow。该方法适用于模型中很少使用控制flow的模型。(就是没有
if...else... for ...
这种的。。) - 第二个方法就是向模型添加显式注释(Annotation),通知Torch Script编译器它可以直接解析和编译模型代码,受Torch Script语言强加的约束。
利用Tracing将模型转换为Torch Script
要通过tracing来将PyTorch模型转换为Torch脚本,必须将模型的实例以及样本输入传递给torch.jit.trace
函数。这将生成一个 torch.jit.ScriptModule
对象,并在模块的forward
方法中嵌入模型评估的跟踪(我用的就是这个):
import torch
import torchvision
# 获取模型实例
model = torchvision.models.resnet18()
# 生成一个样本供网络前向传播 forward()
example = torch.rand(1, 3, 224, 224)
# 使用 torch.jit.trace 生成 torch.jit.ScriptModule 来跟踪
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save('model.pt')
通过Annotation将Model转换为Torch Script
在某些情况下,例如,如果模型使用特定形式的控制流,如果想要直接在Torch Script中编写模型并相应地标注(annotate)模型。例如,假设有以下普通的 Pytorch模型:
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
由于此模块的forward
方法使用依赖于输入的控制流,因此它不适合利用Tracing
的方法生成Torch Script。为此,可以通过继承torch.jit.ScriptModule
并将@ torch.jit.script_method
标注添加到模型的forward
中的方法,来将model转换为ScriptModule:
import torch
class MyModule(torch.jit.ScriptModule):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
@torch.jit.script_method
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
my_script_module = MyModule()
现在,创建一个新的MyModule对象会直接生成一个可序列化的ScriptModule实例了。
后面就可以参考之前的文章(windows+VS2019+PyTorchLib配置使用攻略
)进行调用了~~~( •̀ ω •́ )y
[参考链接]
https://pytorch.apachecn.org/docs/1.0/cpp_export.html