模型参数保存与提取

一、模型参数保存

1.保存整个网络

torch.save(net,'net.pkl')    #保存所有的网络参数

2.保存模型参数

torch.save(net.state_dict(),'net_parameter.pkl')    #保存优化选项默认字典,不保存网络结构

 后缀一般命名为 .pt 或者 .pth

二、网络提取

1.针对保存的是整个网络

net1 = torch.load('net.pkl')

2. 对于提取网络中的参数的方式,必须先完整的建立和需要提取的网络一样的结构的网络,再去提取参数进而恢复网络。

net2 = torch.nn.Sequential(
    nn.Linear(1,20),
    torch.nn.ReLU(),
    nn.Linear(20,20),
    torch.nn.ReLU(),
    nn.Linear(20,1)
)
net2.load_state_dict(torch.load('net_parameter.pkl'))

在大型神经网络中,网络的结构很复杂网络的参数也很发杂,所以直接保存整个网络会占用很大的磁盘资源,就本次实验的一个例子就可以看出,保存网络参数和保存网络结构对磁盘的占用是完全不同的,所以在大型神经网络中更倾向于用保存参数的方式去保存真个网络。

preview

你可能感兴趣的:(pytorch,深度学习,人工智能)