pytorch的模型保存与读取

1. 模型的保存与读取语法

pytorch的模型和参数是分开的,可以分别保存或加载模型和参数。

1.1 保存与加载模型

# 保存模型
torch.save(model, 'model.pth')
# 加载模型
model = torch.load('model.pth')

注意:将模型保存成何种格式文件无所谓(比如pkl,pth等)。

1.2 保存与加载模型参数

# 保存模型参数
torch.save(model.state_dict(), 'model_params.pth')
# 加载模型参数
model.load_state_dict(torch.load('model_params.pth')

1.3 加载预训练模型

resnet18 = models.resnet(pretrained=True)

pretrained=False 表示只加载模型,不加载预训练参数,举一个例子:

# 加载模型
resnet18 = models.resnet18(pretrained=False)
# 加载预先下载好的预训练模型参数
resnet18.load_state_dict(torch.load('resnet18-5c106cde.pth'))

2. 案例实战

import torch
import matplotlib.pyplot as plt

# torch.manual_seed(1)    # reproducible

# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)


def save():
    # save net1
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()

    for t in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # plot result
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)

    # 2 ways to save the net
    torch.save(net1, 'net.pkl')  # save entire net
    torch.save(net1.state_dict(), 'net_params.pkl')   # save only the parameters


def restore_net():
    # restore entire net1 to net2
    net2 = torch.load('net.pkl')
    prediction = net2(x)

    # plot result
    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)


def restore_params():
    # restore only the parameters in net1 to net3
    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )

    # copy net1's parameters into net3
    net3.load_state_dict(torch.load('net_params.pkl'))
    prediction = net3(x)

    # plot result
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.show()

# save net1
save()

# restore entire net (may slow)
restore_net()

# restore only the net parameters
restore_params()

pytorch的模型保存与读取_第1张图片

做以下几点解释:

  1. 在函数 save() 里定义了一种网络结构 net1,并对模型与参数进行了保存。
  2. 采用函数 restore_net() 加载模型时,是加载了一个含有训练好参数的新模型 net2,不过这个新模型与旧模型 net1 是一样的。
  3. 采用函数 restore_params() 加载参数时,需要先定义好模型 net3,并且这个模型需要与旧模型 net1 一样,即与保存参数的网络结构一致,否则不匹配会导致错误。

你可能感兴趣的:(DeepLearning)