这主要有两种方法序列化和恢复模型。
第一种(推荐)只保存和加载模型参数:
torch.save(the_model.state_dict(), PATH)
然后:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
第二种保存和加载整个模型:
torch.save(the_model, PATH)
然后:
the_model = torch.load(PATH)
然而,在这种情况下,序列化的数据被绑定到特定的类和固定的目录结构,所以当在其他项目中使用时,或者在一些严重的重构器之后它可能会以各种方式break。
上面是官方文档给出的解释,可能容易理解但是落实到代码层面可能有点难以下手,所以下面给出一个具体的实例,展示一下如何使用torch.save()函数保存模型或模型参数,torch.load()加载模型.
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())
x,y = Variable(x),Variable(y)
def save():
net = torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1),
)
optimizer = torch.optim.SGD(net.parameters(),lr=0.5)
loss_fun = torch.nn.MSELoss()
for t in range(100):
out = net(x)
loss = loss_fun(out,y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.save(net,'net.pkl')
torch.save(net.state_dict(),'net_params.pkl')
plt.figure(1,figsize=(10,3))
plt.subplot(131)
plt.title('Net')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),out.data.numpy(),'green',lw=5)
def restore_net():
net1 = torch.load('net.pkl')
out = net1(x)
plt.subplot(132)
plt.title('Net1')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),out.data.numpy(),'green',lw=5)
def restore_netparams():
net2 = torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1),
)
net2.load_state_dict(torch.load('net_params.pkl'))
out = net2(x)
plt.subplot(133)
plt.title('Net2')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),out.data.numpy(),'green',lw=5)
plt.show()
save()
restore_net()
restore_netparams()
这里详细解释下:
这个代码是我观看莫烦视频教程时候手码的,很久之前了,还是没有理解太透彻....
整个网络首先定义了三个函数
save() ###保存模型
restore_net() ###载入整个模型
restore_netparams() ##载入模型参数
在save()函数中首先通过nn.Sequential()快速创建基础网络,定义好optimizer和LOSS function后开始进行梯度清零,反向传播和梯度更新,100步之后保存模型
torch.save(net,'net.pkl') ##保存整个模型
torch.save(net.state_dict(),'net_params.pkl') ##保存整个模型参数
对于两种不同的方法定义不同函数进行模型的载入
对于整个模型的保存载入非常简单,直接使用torch.load(PATH),可以直接载入,但是不推荐这种方法,一是对于大量数据集训练迭代次数一般很多(50000...),所以模型一般会非常大,更重要的是这种保存的模型泛化性极其之差.
net1 = torch.load('net.pkl')
out = net1(x)
所以更为推荐下面这种方法
net2 = torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1),
)
net2.load_state_dict(torch.load('net_params.pkl'))
out = net2(x)
这种方法稍微复杂那么一点点,因为你之前只是保存的整个模型的参数,但是并没有给模型所以你需要调用的话,需要把模型框架copy过来,可以对比一下net和net2 没有区别,另外载入的时候指令也稍微长那么一点,一定要记住
net2.load_state_dict(torch.load('net_params.pkl'))
现在对这三行指令应该有点理解了吧
torch.save(the_model.state_dict(), PATH) ##保存模型
the_model = TheModelClass(*args, **kwargs) ##给定网络
the_model.load_state_dict(torch.load(PATH)) ##载入模型
另外呢,既然提到了网络,之前在使用caffe的时候也利用可视化工具看过Alexnet,vgg16,resnet等网络的模型结构图,具体可以看下之前的博客CNN经典分类模型--AlexNet、VGG16、ResNet网络结构图
如果想用代码实现具体网络的搭建,这里推荐pytorch给定的models文件https://github.com/pytorch/vision/tree/master/torchvision/models
每一个py文件都是具体网络的编写,如果想要直接调用的话也非常简单,
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()
当然,如果你想使用pre-trained model ,只需要更改pretrained的布尔值为True就OK ;
import torchvision.models as models
#pretrained=True就可以使用预训练的模型
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
这里给出官方中文的具体介绍,推荐去看下
最后呢,感谢莫烦大佬的视频教程,非常适用于小编这样的初学者入门,给个传送门,去B站投币吧
莫烦PyTorch入门教程