(Pytorch)模型的保存与加载

对于一个已经训练完毕的网络,保存模型可以方便后续直接使用。

再使用时,通过加载方式即可。

方法一、保存、加载整个模型

#mymodel.pkl是生成的文件名
#save
torch.save(model.dtate_dict(),'mymodel.pkl')  #
#load
model=torch.load('mymodel.pkl')    #

方法二、保存、加载模型的参数♥

只保存模型的参数,节省空间,在后续加载时可以去除特定层的参数

#save 
torch.save(model.state_dict(),'mymodel.pkl')
#load
model_object.load_state_dict(torch.load('mymodel.pkl'))   #model_object是新文件中新定义的模型名

例如:

(Pytorch)模型的保存与加载_第1张图片

方法三、把 别的模型中相同的网络参数 加载到 新的模型中

可以用已经训练好的网络参数作为自己模型的网络权重的初始化

如下,实现了从model_frommodel to的相同网络参数的拷贝:

def transfer_weights(model_from, model_to):
    wf = copy.deepcopy(model_from.state_dict())  #对model_from 中的模型参数的深层拷贝
    wt = model_to.state_dict()  #获取model_to模型参数
    #for循环的目的是让wf扩充后的结构跟wt一样,即保留model_from的模型参数,又将结构扩充到和model_to的一样。
    for k in wt.keys() :  #若在model_to中出现的网络结构,但在model_from中没有出现,则拷贝一份给wf.
        if (not k in wf)):      
            wf[k] = wt[k]
    model_to.load_state_dict(wf)  #load_state_dict函数加载想要的模型参数到目标模型model_to中

 

你可能感兴趣的:(pytorch,深度学习,神经网络)