1. Method 1
Saving:
torch.save(model, 'mlp_mnist.pth')  # model is the model object; mlp_mnist.pth is the output file name
Saving the model this way means its class definition must be available when loading, which is error-prone; the second method is recommended.
Loading works as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, n_in, n_hid1, n_hid2, n_out):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(n_in, n_hid1)
        self.fc2 = nn.Linear(n_hid1, n_hid2)
        self.fc3 = nn.Linear(n_hid2, n_out)
        self.bn1 = nn.BatchNorm1d(n_hid1)

    def forward(self, x):
        x = x.view(x.shape[0], -1)   # flatten (N, 1, 28, 28) -> (N, 784)
        x = self.fc1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        x = torch.relu(x)
        x = self.fc3(x)
        x = F.softmax(x, dim=1)      # class probabilities
        return x
model7 = torch.load('mlp_mnist.pth')
print(model7)
print(model7.fc1.weight.size())
# Predict
x = torch.randn(1,1,28,28)
print(model7(x))
The output is as follows:
MLP(
(fc1): Linear(in_features=784, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=64, bias=True)
(fc3): Linear(in_features=64, out_features=10, bias=True)
(bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
torch.Size([128, 784])
tensor([[1.7218e-08, 2.6122e-07, 4.3788e-04, 3.1814e-03, 1.0310e-17, 1.5054e-02,
1.5403e-14, 9.8133e-01, 9.0798e-16, 1.0884e-13]],
       grad_fn=<SoftmaxBackward0>)
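Note that torch.save stores the module together with its current training/eval mode, and layers such as BatchNorm1d behave differently at inference time. A minimal sketch of switching to evaluation mode before predicting (variable names follow the example above):

model7 = torch.load('mlp_mnist.pth')
model7.eval()                        # use BatchNorm running statistics instead of batch statistics
with torch.no_grad():                # gradients are not needed for prediction
    x = torch.randn(1, 1, 28, 28)
    probs = model7(x)
    print(torch.argmax(probs, dim=1))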
2. Method 2
Saving:
model_scripted = torch.jit.script(model)   # compile the model to TorchScript
model_scripted.save('mlp_scripted.pt')     # the saved file carries the graph, so no Python class is needed to load it
Note: the model can be saved every epoch, only at the last epoch, or once every few epochs, so that the model is not lost if the program terminates unexpectedly, as in the sketch below.
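A minimal sketch of saving a checkpoint every few epochs inside a training loop (the names train_one_epoch, n_epochs, and save_every are hypothetical placeholders, not part of the original example):

n_epochs = 50
save_every = 5                                  # hypothetical checkpoint interval
for epoch in range(n_epochs):
    train_one_epoch(model)                      # hypothetical single-epoch training step
    if (epoch + 1) % save_every == 0 or epoch == n_epochs - 1:
        torch.jit.script(model).save(f'mlp_scripted_epoch{epoch + 1}.pt')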
When loading a model saved with Method 2, the network definition is not needed. The prediction code below uses the iris example: the input has four features, the output is three class probabilities, and the index of the largest probability is the predicted class.
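For reference, a minimal sketch of how the saved Sequential model (matching the structure printed below) might have been defined and scripted; the hidden-layer size of 16 is an assumption, only the 4-feature input and 3-class output come from the example:

import torch
import torch.nn as nn

# 4 input features -> hidden layer -> 3 class probabilities (hidden size is assumed)
model = nn.Sequential(
    nn.Linear(4, 16),
    nn.ReLU(),
    nn.Linear(16, 3),
    nn.Softmax(dim=1),
)
model_scripted = torch.jit.script(model)
model_scripted.save('mlp_scripted.pt')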
Loading and prediction:
# No need to import the model definition
model7 = torch.jit.load('mlp_scripted.pt')
print(model7)
# Predict
x = torch.randn(1, 4)
print(model7(x))
pred = torch.argmax(model7(x), dim=1)
print(pred)
The output is as follows:
RecursiveScriptModule(
original_name=Sequential
(0): RecursiveScriptModule(original_name=Linear)
(1): RecursiveScriptModule(original_name=ReLU)
(2): RecursiveScriptModule(original_name=Linear)
(3): RecursiveScriptModule(original_name=Softmax)
)
tensor([[0.2346, 0.7643, 0.0011]], grad_fn=<SoftmaxBackward0>)
tensor([1])
The predicted result is 1, i.e. the second class.