【torch保存神经网络模型并加载和测试】

【torch保存神经网络模型并加载和测试】

1、方式一
保存:

torch.save(model, 'mlp_mnist.pth')  # model模型名,mlp_mnist.pth模型文件名

用这种方式保存模型,加载时需要加载模型结构,容易出错,推荐使用第二种方法。
加载的方式如下:

class MLP(nn.Module):
    def __init__(self, n_in, n_hid1, n_hid2, n_out):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(n_in, n_hid1)
        self.fc2 = nn.Linear(n_hid1, n_hid2)
        self.fc3 = nn.Linear(n_hid2, n_out)
        self.bn1 = nn.BatchNorm1d(n_hid1)
        
    def forward(self, x):
        x = x.view(x.shape[0], -1)
        x = self.fc1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.fc3(x)
        x = F.softmax(x, dim=1)
        return x
    

model7 = torch.load('mlp_mnist.pth')
print(model7)
print(model7.fc1.weight.size())
# 预测
x = torch.randn(1,1,28,28)
print(model7(x))

输出如下:

MLP(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=10, bias=True)
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
torch.Size([128, 784])
tensor([[1.7218e-08, 2.6122e-07, 4.3788e-04, 3.1814e-03, 1.0310e-17, 1.5054e-02,
         1.5403e-14, 9.8133e-01, 9.0798e-16, 1.0884e-13]],
       grad_fn=)

2、方式二
保存:

model_scripted = torch.jit.script(model)
model_scripted.save('mlp_scripted.pt')

注意:模型保存可以在每轮保存,或者最后一轮保存,或者隔一定轮数保存一次,防止程序因意外终止而导致模型被清除。
使用方式二加载模型,不需要网络结构,预测的代码如下,这是鸢尾花的例子,输入四个变量,输出3个类别概率,取最大概率序号,即为类别号

加载与预测:

#不必导入模型定义
model7 = torch.jit.load('mlp_scripted.pt')
print(model7)

# 预测
x = torch.randn(1,4)
print(model7(x))
pred = torch.argmax(model7(x), axis =1)
print(pred)

输出如下:

RecursiveScriptModule(
  original_name=Sequential
  (0): RecursiveScriptModule(original_name=Linear)
  (1): RecursiveScriptModule(original_name=ReLU)
  (2): RecursiveScriptModule(original_name=Linear)
  (3): RecursiveScriptModule(original_name=Softmax)
)
tensor([[0.2346, 0.7643, 0.0011]], grad_fn=)
tensor([1])

预测结果为1,属于第二类。

你可能感兴趣的:(Python编程基础,神经网络,python,机器学习)