pytorch模型转torchscript

文章目录

    • 目的
    • 方法
        • trace
        • script

目的

将pytorch模型转化成torchscript目的就是为了可以在c++环境中调用pytorch模型。
pytorch官方链接

方法

共有两种方法将pytorch模型转成torch script ,一种是trace,另一种是script。一版在模型内部没有控制流存在的话(if,for循环),直接用trace方法就可以了。如果模型内部存在控制流,那就需要用到script方法了。

trace

通过使用示例输入对模型的结构进行一次评估,并记录这些输入在模型中的变化过程,从而捕获模型的结构。

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule,self).__init__()
        self.conv1 = nn.Conv2d(1,3,3)

    def forward(self,x):
    	x = self.conv1(x)
        return x

model = MyModule()  # 实例化模型
trace_module = torch.jit.trace(model,torch.rand(1,1,224,224)) 
print(trace_module.code)  # 查看模型结构
output = trace_module (torch.ones(1, 3, 224, 224)) # 测试
print(output)
trace_modult('model.pt') # 模型保存

script

如果模型内部有控制流结构,用trace就会报错。

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule,self).__init__()
        self.conv1 = nn.Conv2d(1,3,3)
        self.conv2 = nn.Conv2d(2,3,3)

    def forward(self,x):
        b,c,h,w = x.shape
        if c ==1:
            x = self.conv1(x)
        else:
            x = self.conv2(x)
        return x

model = MyModule()

# 这样写会报错,因为有控制流
# trace_module = torch.jit.trace(model,torch.rand(1,1,224,224)) 

# 此时应该用script方法
script_module = torch.jit.script(model) 
print(script_module.code)
output = script_module(torch.rand(1,1,224,224))

你可能感兴趣的:(深度学习,机器学习,pytorch,机器学习,数据挖掘)