pytorch中模型参数的保存与读取

对训练好的模型,有两种保存方法:
1.直接将训练好的神经网络进行保存,但是速度会比较慢
2.将训练好的神经网络参数,保存到文件当中,然后进行文件的读取,再将读出的参数赋给新建好的模型,要求新建好的模型与之前的模型相同

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim

#1构建数据集:y=x2
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)

#2.搭建神经网络
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.hidden=nn.Linear(1,10)
        self.predict=nn.Linear(10,1)

    def forward(self,x):
        x=F.relu(self.hidden(x))
        x=self.predict(x)
        return x

net=Net()

#3.定义优化器和损失函数
optimizer=torch.optim.SGD(net.parameters(),lr=0.5)
loss_func=torch.nn.MSELoss()

#4.训练模型
for epoch in range(1,101):
    prediction=net(x) #向前传播 得到预期值
    loss=loss_func(prediction,y) #向前传播 算出损失量 构建计算图

    print(epoch,loss)

    optimizer.zero_grad()
    loss.backward() #向后传播 算出梯度 释放计算图
    optimizer.step()# 梯度下降

print(net.hidden.weight.detach())  #net.hidden.weight.item()
print(net.hidden.bias.detach())
print(net.predict.weight.detach())
print(net.predict.weight.detach())

#保存模型参数至文件
torch.save(net.state_dict(), 'net_parameters.pt')
#实例化参数,赋值
m_state_dict = torch.load('net_parameters.pt')
new_net=Net()
new_net.load_state_dict(m_state_dict)

print(new_net.hidden.weight.detach())  #net.hidden.weight.item()
print(new_net.hidden.bias.detach())
print(new_net.predict.weight.detach())
print(new_net.predict.weight.detach())

#保存模型至文件
torch.save(net, 'net.pt')
#实例化模型
new_nets=torch.load('net.pt')

print(new_nets.hidden.weight.detach())  #net.hidden.weight.item()
print(new_nets.hidden.bias.detach())
print(new_nets.predict.weight.detach())
print(new_nets.predict.weight.detach())

pytorch中模型参数的保存与读取_第1张图片

你可能感兴趣的:(深度学习入门,神经网络,深度学习,卷积神经网络)