C++ windows调用ubuntu训练的PyTorch模型(.pt/.pth)

之前写到用ubuntu上训练pytorch的网络,再window c++调用的文章。实际调用过程中,并不是简单load就可以,需要将PyTorch模型转换为Torch Script,在这里记录一下,供有需要的盆友参考。

PyTorch模型从Python到C++的转换由Torch Script实现。Torch Script是PyTorch模型的一种表示,可由Torch Script编译器理解,编译和序列化。

将PyTorch模型转换为Torch Script有两种方法。

  • 第一种方法是Tracing。该方法通过将样本输入到模型中一次来对该过程进行评估从而捕获模型结构.并记录该样本在模型中的flow。该方法适用于模型中很少使用控制flow的模型。(就是没有if...else... for ...这种的。。)
  • 第二个方法就是向模型添加显式注释(Annotation),通知Torch Script编译器它可以直接解析和编译模型代码,受Torch Script语言强加的约束。
利用Tracing将模型转换为Torch Script

要通过tracing来将PyTorch模型转换为Torch脚本,必须将模型的实例以及样本输入传递给torch.jit.trace函数。这将生成一个 torch.jit.ScriptModule对象,并在模块的forward方法中嵌入模型评估的跟踪(我用的就是这个):

import torch
import torchvision

# 获取模型实例
model = torchvision.models.resnet18()

# 生成一个样本供网络前向传播 forward()
example = torch.rand(1, 3, 224, 224)

# 使用 torch.jit.trace 生成 torch.jit.ScriptModule 来跟踪
traced_script_module = torch.jit.trace(model, example)

traced_script_module.save('model.pt')
通过Annotation将Model转换为Torch Script

在某些情况下,例如,如果模型使用特定形式的控制流,如果想要直接在Torch Script中编写模型并相应地标注(annotate)模型。例如,假设有以下普通的 Pytorch模型:

import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

由于此模块的forward方法使用依赖于输入的控制流,因此它不适合利用Tracing的方法生成Torch Script。为此,可以通过继承torch.jit.ScriptModule并将@ torch.jit.script_method标注添加到模型的forward中的方法,来将model转换为ScriptModule:

import torch

class MyModule(torch.jit.ScriptModule):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    @torch.jit.script_method
    def forward(self, input):
        if input.sum() > 0:
          output = self.weight.mv(input)
        else:
          output = self.weight + input
        return output

my_script_module = MyModule()

现在,创建一个新的MyModule对象会直接生成一个可序列化的ScriptModule实例了。

后面就可以参考之前的文章(windows+VS2019+PyTorchLib配置使用攻略
)进行调用了~~~( •̀ ω •́ )y

[参考链接]
https://pytorch.apachecn.org/docs/1.0/cpp_export.html

你可能感兴趣的:(C++ windows调用ubuntu训练的PyTorch模型(.pt/.pth))