改动模型后,加载部分预训练权重文件

加载部分预训练权重文件

最近在做姿态估计相关内容,需要将Hrnet模型修改,Hrnet是基于coco数据集训练的,coco数据集是17个关节点,而我需要的是15个关节点,在将数据集标好训练之后,发现由于数据量比较小,模型能够很快收敛,但是泛化性能极差,于是,就想着把之前的coco预训练权重文件拿出来一部分,对我自己的模型进行训练,果不其然,在使用部分预训练权重文件进行训练后,模型的泛化性有了很大的改善,现在分享给大家。

1首先我们需要明确权重文件的类型是什么:我们在使用pytorch进行模型训练的时候,最后的权重文件实际上是一个字典, 只不过是一个有序字典OrderedDict类 ,关于这个类的各种操作,请参考这篇博客, 里面已经说的很详细了OrderedDict

2在明确权重文件其实就是一个字典类的时候,那么我们就能了解,权重文件其实就是key+value,所谓key就是每一层的关键字,而value就是每一层的矩阵数据,下面以一份权重文件为例:

改动模型后,加载部分预训练权重文件_第1张图片

我们在加载预训练权重文件之后,发现就是一个字典,并且是一个有序字典,那么同样,我们可以打印出字典的关键字:

改动模型后,加载部分预训练权重文件_第2张图片

在了解上述操作过程以后,那么如果我们想要加载部分预训练权重文件就很简单啦。

首先,我们需要将我们实例化删改后的模型:

model = YOUR_changed_model(**)

其次,加载你删改后模型的state_dict()

model_state_dict= model.state_dict()

同样,model_state_dict()也是一个字典文件,因为我们已经改变了模型,但是我们改变的只是模型的一部分,换句话说,改变的只是权重文件字典中的某些keys或者values,而我们加载的部分权重文件其实就是在原来权重文件中没有修改的。以我的模型为例,我只是改变了模型的最后的全连接层,本来最后一维是17维,我需要的是15维,那么也就是说出去最后一层的预训练权重文件,我都是可以使用的,并且最后一维我可以使用预训练权重文件的前15维,因此,修改如下:

for i, (k, v) in enumerate(model_state.items()):
    if i < 1752:
        model_state[k] = pretrained_weights[k] 
    else:
        model_state[k] = pretrained_weights[k][:-2]
torch.save(model_state, 'best.pt')  # 保存权重文件

最后由于我的模型修改比较简单,所以,调取预训练权重文件也比较容易,但是核心思想是一致的,就是把权重文件看作是一个字典,在我们新的模型中添加原来权重文件中存在的key以及value。

你可能感兴趣的:(姿态估计,机器学习,深度学习)