Pytorch 使用预训练模型

上一篇讲了如何载入模型,这里写一下如何使用载入的模型初始化新网络的部分层:
我的理解在于,在pytorch中,模型的参数是按照字典的形式存储的,key为该层的名称,相应的value是这层的参数,理解了之后,其实更新一个新的网络的参数,也就是用一个已经存在的字典(也就是预训练的模型的参数)来更新新的字典(新的模型的参数):

  1. 网络结构的定义
# DenseNet这个类就是网络denesnet的结构的定义,这里参考了pytorch 里面models的源码
class DenseNet(nn.Module):
    ... ...(此处网络结构的定义省略
class fw_DenseNet(nn.Module):
    ... ...(这个是我修改后网络结构)
  1. 获取网络参数:
'''预训练的模型'''
net = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24))
net.load_state_dict(torch.load('/home/wei.fan/.torch/models/densenet161-17b70270.pth'))
net_dict = net.state_dict()  #获取预训练模型的参数
''' 自定义的网络模型'''
net1 = fw_DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24))
net1_dict = net1.state_dict() #获取参数,但其实没有参数(因为没有训练)
  1. 使用预训练的模型的参数,更新自定义模型的参数
net1_dict = {k: v for k, v in net_dict.items() if k in net1_dict} #把两个模型中名称不同的层去掉
net1_dict.update(net1_dict) #使用预训练模型更新新模型的参数
net1.load_state_dict(net1_dict) #更新模型

参考了这里:How to load part of pre trained model?
其实核心代码只有下面三句,但是因为pytorch的内部机制不清楚,所以搞了蛮久才弄懂,这里贴出来让后来者少踩几个坑吧。

# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict) 
# 3. load the new state dict
model.load_state_dict(pretrained_dict)

你可能感兴趣的:(Pytorch 使用预训练模型)