深度学习迁移学习 仅更新classifier pytorch

1. 问题描述
希望用预训练好的模型提取特征,仅更新classifier部分。但是和常规的不同的是,如果直接用model.load_state_dict(trained_state_dict)会存在key对不上的问题
 

2. 解决办法
在建模型的时候写好init_parameters()的函数,用训练好的模型来初始化,并且把权重的参数require_grad=False
其中最重要的就是:
 param0.data = g_dict['block' + name0].data

 def frozen_parameters(self, cfg, logger):
        import os
        model_state_file_g = os.path.join(cfg.OUTPUT_DIR, 'chromosome', 'GNet',
                                          'w32_256x256_adam_lr1e-3', 'checkpoint.pth')
        model_state_file_l = os.path.join(cfg.OUTPUT_DIR, 'chromosome', 'LNet',
                                          'w32_256x256_adam_lr1e-3', 'checkpoint.pth')
        g_trained = torch.load(model_state_file_g)['state_dict']
        l_trained = torch.load(model_state_file_l)['state_dict']

        g_dict = {k.replace('module.', ''): v.cpu() for k, v in g_trained.items()}
        for name0, param0 in self.bone_glocal.named_parameters():
            param0.requires_grad = False
            parts = name0.split('.')
            list_index = parts[0]
            if list_index == '0':
                param0.data = g_dict['conv3x3.' + parts[-1]].data
            elif list_index == '4':
                param0.data = g_dict['bn.' + parts[-1]].data
            else:
                param0.data = g_dict['block' + name0].data

        l_dict = {k.replace('module.', ''): v.cpu() for k, v in l_trained.items()}
        for name1, param1 in self.bone_local.named_parameters():
            param1.requires_grad = False
            parts = name1.split('.')
            list_index = parts[0]
            if list_index == '0':
                param1.data = l_dict['conv3x3.' + parts[-1]].data
            elif list_index == '4':
                param1.data = l_dict['bn.' + parts[-1]].data
            else:
                param1.data = l_dict['block' + name1].data
        logger.info('=> loading gnet, lnet model from {}, {}'.format(model_state_file_g, model_state_file_l))
        logger.info('gnet_epoch: {}'.format(torch.load(model_state_file_g)['epoch']))
        logger.info('lnet_epoch: {}'.format(torch.load(model_state_file_l)['epoch']))

在optimizer定义时:

optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=cfg.TRAIN.LR,
        )

3. 踩过的坑 
(1)self.bone_local.named_parameters()返回的是tuple,不能修改
(2)module.named_parameters()获得参数后不能直接赋值,是tuple类型,但是para.data可以。(疑惑)
 

4. 其他:
(1)module.named_parameters()可以看到权重的名称和参数
(2)更多参考https://blog.csdn.net/qq_32998593/article/details/89343507

深度学习迁移学习 仅更新classifier pytorch_第1张图片

深度学习迁移学习 仅更新classifier pytorch_第2张图片

深度学习迁移学习 仅更新classifier pytorch_第3张图片

深度学习迁移学习 仅更新classifier pytorch_第4张图片

 

 

# nn.init._no_grad_fill_(param0, g_dict['conv3x3.' + parts[-1]])

你可能感兴趣的:(deep,learning)