pytorch仅加载网络中相同部分的预训练权重

pytorch仅加载网络中相同部分的预训练权重

在复现论文时,由于对网络结构进行了修改,部分层结构无法加载之前训练的权重,这里记录一下解决方案。

问题回顾

自定义模型时对部分网络结构进行了调整。

model = MyModel()
# 通过state_dict()获取到参数列表
model.state_dict().keys()
>>>
odict_keys([
	'features.0.weight', 
	'features.1.running_mean',
	'features.1.running_var',
	'features.1.num_batches_tracked', 
	...
	])

而对于权重文件:

pretrained_weight = torch.load(file_path)
pretrained_weight.state_dict().keys()
>>>
odict_keys([
	'features.0.weight', 
	'features.1.weight',
	'features.1.num_batches_tracked',
	...
	])

这里给出的例子中,我们定义的网络与预训练权重仅第一层一样。
也就是说,我们只需要把对应的层参数加载即可。

解决方法

通过state_dict()方法拿到权重文件和模型的参数字典,将相同层级的参数进行导入即可。

这里给出自己写的一个小demo:

def load_weight(model, path):
    pretrained_weight = torch.load(path)
    for key in pretrained_weight.keys():
        model.state_dict()[key] = pretrained_weight[key]
    return model

你可能感兴趣的:(pytorch,deep,learning)