[Pytorch].pth转.pt文件

Pytorch的模型文件一般会保存为.pth文件,C++接口一般读取的是.pt文件,因此,C++在调用Pytorch训练好的模型文件的时候就需要进行一个转换,转换为.pt文件,才能够读取。

所以在转换的时候,首先就需要先将模型文件读取进来,然后利用pytorch提供的函数torch.jit.trace进行转换,这个函数的声明为:

def trace(func,
          example_inputs,
          optimize=True,
          check_trace=True,
          check_inputs=None,
          check_tolerance=1e-5,
          _force_outplace=False,
          _module_class=None):
也就是,第一个参数为输入的模型,第二个参数为输入的带测试数据,通常其数据形式要跟模型的输入数据的形式是一样的。

转换的代码例子如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchsummary import summary
 
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, 1)
        self.conv2 = nn.Conv2d(32, 64, 5, 1)
        self.fc1 = nn.Linear(4*4*64, 512)
        self.fc2 = nn.Linear(512, 10)
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
 
model = torch.load("mnist_cnn.pth")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
summary(model, input_size=(1, 28, 28))
model = model.to(device)
traced_script_module = torch.jit.trace(model, torch.ones(1, 1, 28, 28).to(device))
traced_script_module.save("mnist_cnn_cc1.pt")
 
————————————————
版权声明:本文为CSDN博主「熊叫大雄」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/yz2zcx/article/details/100609210

你可能感兴趣的:(#,Pytorch框架,pytorch,深度学习,神经网络)