Pytorch模型转Torchscript模型

TorchScript,它是PyTorch模型(子类nn.Module)的中间表示,可以在高性能环境(例如C ++)中运行。

转换的方法有两种,一种是通过追踪转换另一种是通过注释转换。

1、追踪转换

常用的是追踪转换,但是这种方法有一个缺点,就是输入尺寸固定。

官网给出的追踪的例子:

import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

The traced ScriptModule can now be evaluated identically to a regular PyTorch module:

In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
In[2]: output[0, :5]
Out[2]: tensor([-0.2698, -0.0381,  0.4023, -0.3010, -0.0448], grad_fn=

你可能感兴趣的:(pytorch,深度学习,pytorch实践)