Pytorch模型加载保存

Pytorch提供了两种模型的保存和加载方法。

一、首先定义一个名为Example的模型

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Example(nn.Module):
    def __init__(self):
        super(Example, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 初始化模型
model_init = Examle()

    保存的文件后缀一般使用.pt.pth,将保存后的模型使用到测试模式时,需要使用model.eval(),它表示将dropbatch nromalization层设置为测试模式。训练时候,需要通过设置mode.train()转化为训练模式。

1. 第一种是只保存模型的参数。

使用这种方法保存模型,当测试时候我们需要自己导入模型的结构信息。

#保存模型参数
torch.save(model.state_dict(), PATH)

例:将上面定义的模型model_init保存在文件夹ckp,以model.pth形式保存。

torch.save(model_init.state_dict(),'ckp/model.pth') 

当测试时候,我们需要加载保存的模型,来进行测试。
通用方法:

model = ModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))

加载我们刚刚保存的模型进行测试,当然我们需要使用定义的网络结构model_init

model_init .load_state_dict(torch.load('ckp/model.pth'))

2. 第二种是保存模型的结构和模型的参数

保存模型

torch.save(model, PATH)

例如:保存定义好的model_init模型

torch.save(model_init , 'ckp/model.pth')

加载模型

model = torch.load(PATH)
model.eval()

例如:加载保存好的model_init模型

model = torch.load('ckp/model.pth')
model.eval()

当然,第二种模型文件占用的内存要大于第一种方法。

你可能感兴趣的:(pytorch,深度学习)