Staged decoupled training: parameter transfer between models

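While reproducing a model, I needed to decouple the forward and reverse inference processes of an INN (invertible neural network) in the later phase of training: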

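  • stage1: train the INN normally. The encoder and decoder are the INN's forward and backward (reverse) processes and share one set of network parameters;
  • stage2: decouple the INN. The encoder is frozen and the decoder is fine-tuned, so that the INN can adapt to the non-invertible process in the task.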
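A minimal sketch of what "sharing one set of parameters" means in stage 1, assuming a toy additive coupling layer in place of the real Hinet. The class `AdditiveCoupling` and its sizes are illustrative only, not the original model. The same sub-network `f` is used in both directions, so the reverse pass is an exact inverse of the forward pass. A non-invertible step between encoder and decoder breaks this exactness, which is what stage 2 addresses.

```python
import torch
import torch.nn as nn


class AdditiveCoupling(nn.Module):
    """Toy invertible block (hypothetical stand-in for Hinet).

    Splits the input into halves (x1, x2):
      forward: y1 = x1, y2 = x2 + f(x1)
      reverse: x1 = y1, x2 = y2 - f(y1)
    Both directions use the same sub-network f, i.e. one shared set of parameters.
    """

    def __init__(self, channels: int = 4):
        super().__init__()
        half = channels // 2
        self.f = nn.Sequential(nn.Linear(half, half), nn.Tanh(), nn.Linear(half, half))

    def forward(self, x: torch.Tensor, rev: bool = False) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        if not rev:
            return torch.cat([x1, x2 + self.f(x1)], dim=-1)  # encoder direction
        return torch.cat([x1, x2 - self.f(x1)], dim=-1)      # decoder direction


torch.manual_seed(0)
inn = AdditiveCoupling(channels=4)
x = torch.randn(8, 4)
encoded = inn(x)                  # stage 1 encoder: forward process
decoded = inn(encoded, rev=True)  # stage 1 decoder: reverse process, same parameters
print(torch.allclose(decoded, x, atol=1e-6))  # True: the reverse pass inverts the forward pass
```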

To implement this two-stage training strategy, we need the following three steps:

1. For stage 2, add a module with the same architecture as the stage-1 INN to serve as the decoder (the network trained in stage 1 serves as the encoder):

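```python
import torch.nn as nn

# Hinet is the INN implementation from the project being reproduced (defined elsewhere)


class MyHiNet(nn.Module):
    def __init__(self, in_1=3, in_2=3, dwt_opt=True):
        super(MyHiNet, self).__init__()
        # Stage 1 INN: its forward pass is the encoder, its reverse pass the decoder (shared parameters)
        self.inbs_couple = Hinet(in_1=in_1, in_2=in_2, dwt_opt=dwt_opt)
        # Stage 2 decoder: same architecture, separate parameters
        self.inbs_decouple = Hinet(in_1=in_1, in_2=in_2, dwt_opt=dwt_opt)

    def forward(self, x, rev=False, decouple=False):
        if not rev:
            # The forward (encoding) pass always uses inbs_couple
            return self.inbs_couple(x)
        if decouple:
            # Stage 2: the reverse (decoding) pass uses the separate decoder
            return self.inbs_decouple(x, rev=True)
        # Stage 1: the reverse pass of the shared INN
        return self.inbs_couple(x, rev=True)
```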
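In the training loop, `stage` selects which module decodes (in the code, `stage == 0` is stage 1 and `stage == 1` is stage 2):

```python
# Forward (encoding) pass: always through inbs_couple
output = net(input)

# Backward (reverse, decoding) pass
if stage == 0:
    # Stage 1: reverse pass of the shared INN
    output_rev = net(input_rev, rev=True, decouple=False)
elif stage == 1:
    # Stage 2: reverse pass through the decoupled decoder
    output_rev = net(input_rev, rev=True, decouple=True)
```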
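A runnable sketch of the same routing, assuming a hypothetical `ToyINN` (a learnable per-channel scaling, y = x·exp(s), inverse x = y·exp(−s)) in place of Hinet. It shows that `decouple=True` really sends the reverse pass through a different set of weights. Left uninitialised, that decoder no longer inverts the encoder, which is why step 2 below copies the encoder's weights into it.

```python
import torch
import torch.nn as nn


class ToyINN(nn.Module):
    """Hypothetical stand-in for Hinet: y = x * exp(s), inverse x = y * exp(-s)."""

    def __init__(self, channels: int = 3):
        super().__init__()
        self.log_scale = nn.Parameter(torch.zeros(channels))

    def forward(self, x, rev=False):
        scale = torch.exp(-self.log_scale if rev else self.log_scale)
        return x * scale


class TwoStageNet(nn.Module):
    """Same routing as MyHiNet, built on ToyINN."""

    def __init__(self, channels: int = 3):
        super().__init__()
        self.inbs_couple = ToyINN(channels)    # encoder (and stage-1 decoder)
        self.inbs_decouple = ToyINN(channels)  # stage-2 decoder

    def forward(self, x, rev=False, decouple=False):
        if not rev:
            return self.inbs_couple(x)              # forward always uses the encoder
        if decouple:
            return self.inbs_decouple(x, rev=True)  # stage 2: separate decoder
        return self.inbs_couple(x, rev=True)        # stage 1: shared INN


net = TwoStageNet()
with torch.no_grad():
    net.inbs_couple.log_scale.fill_(0.5)    # pretend stage 1 trained the encoder
    net.inbs_decouple.log_scale.fill_(0.0)  # decoder not yet initialised from the encoder

x = torch.ones(2, 3)
y = net(x)
print(torch.allclose(net(y, rev=True, decouple=False), x))  # True: shared INN inverts exactly
print(torch.allclose(net(y, rev=True, decouple=True), x))   # False: decoder has different weights
```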

2. At the start of stage 2, transfer the parameters from the encoder to the decoder:

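```python
# Load the pretrained (stage 1) model; load() is the project's checkpoint-loading helper
if load_path != '':
    net = load(net, load_path)
    if stage == 1:
        # Parameter transfer: inbs_couple -> inbs_decouple
        # (net.module assumes net is wrapped in nn.DataParallel)
        inbs_state_dict = net.module.inbs_couple.state_dict()
        net.module.inbs_decouple.load_state_dict(inbs_state_dict, strict=True)
    print(f'load from {load_path}')
```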
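A minimal sketch of the transfer as a reusable function, assuming a hypothetical container with the same submodule names (two `nn.Linear` layers stand in for the Hinet modules). `getattr(net, "module", net)` lets the same helper work with or without a `DataParallel` wrapper. It also checks that `load_state_dict` copies values into the decoder's own tensors rather than sharing storage, so fine-tuning the decoder cannot alter the frozen encoder.

```python
import torch
import torch.nn as nn


class TwoStageNet(nn.Module):
    """Hypothetical container with the same submodule names as MyHiNet."""

    def __init__(self):
        super().__init__()
        self.inbs_couple = nn.Linear(4, 4)
        self.inbs_decouple = nn.Linear(4, 4)


def transfer_couple_to_decouple(net: nn.Module) -> None:
    """Copy inbs_couple's weights into inbs_decouple (hypothetical helper).

    getattr(net, "module", net) works whether or not net is wrapped in
    nn.DataParallel / DistributedDataParallel.
    """
    model = getattr(net, "module", net)
    model.inbs_decouple.load_state_dict(model.inbs_couple.state_dict(), strict=True)


net = nn.DataParallel(TwoStageNet())  # mirrors the net.module access in the code above
transfer_couple_to_decouple(net)

m = net.module
print(torch.equal(m.inbs_couple.weight, m.inbs_decouple.weight))  # True: values copied

# The copy lives in inbs_decouple's own tensors, so changing the decoder
# leaves the encoder untouched.
with torch.no_grad():
    m.inbs_decouple.weight.add_(1.0)
print(torch.equal(m.inbs_couple.weight, m.inbs_decouple.weight))  # False
```

Assigning the same module to both attributes (`self.inbs_decouple = self.inbs_couple`) would instead share the parameters, and freezing the encoder would then freeze the decoder too.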

Testing the parameter transfer:

```python
# ############ TEST: parameter transfer ############ #
net = MyHiNet()
print(f'load from {load_path_couple}')
net = load(net, load_path_couple)  # net.module below assumes load() returns a DataParallel-wrapped net

# Parameter transfer: inbs_couple -> inbs_decouple
inbs_state_dict = net.module.inbs_couple.state_dict()
net.module.inbs_decouple.load_state_dict(inbs_state_dict, strict=True)

# Parameters of inbs_couple
print('inbs_couple:\n')
for param_tensor in inbs_state_dict:
    print(param_tensor, "\t", inbs_state_dict[param_tensor])

# Parameters of inbs_decouple after the transfer
print('inbs_decouple:\n')
inbs_decouple_state_dict = net.module.inbs_decouple.state_dict()
for param_tensor in inbs_decouple_state_dict:
    print(param_tensor, "\t", inbs_decouple_state_dict[param_tensor])
```

After the transfer, the parameters of inbs_decouple are identical to those of inbs_couple:
[Figure 1: printed parameters of inbs_couple and inbs_decouple]
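Printing every tensor and comparing by eye works for a small model; a minimal sketch of an automatic check compares the two state dicts key by key with `torch.equal`. `same_parameters` is a hypothetical helper, and the two small conv stacks stand in for the real `inbs_couple` / `inbs_decouple` modules.

```python
import torch
import torch.nn as nn


def same_parameters(a: nn.Module, b: nn.Module) -> bool:
    """Return True if a and b have identical state_dict keys and tensor values."""
    sa, sb = a.state_dict(), b.state_dict()
    if sa.keys() != sb.keys():
        return False
    return all(torch.equal(sa[k], sb[k]) for k in sa)


def make_block() -> nn.Module:
    # Hypothetical stand-in for one Hinet module
    return nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 3, 3, padding=1))


couple, decouple = make_block(), make_block()
print(same_parameters(couple, decouple))  # False: independent random initialisation
decouple.load_state_dict(couple.state_dict(), strict=True)
print(same_parameters(couple, decouple))  # True after the transfer
```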

Reference: Parameter sharing between models in PyTorch (CSDN blog), https://blog.csdn.net/cyj972628089/article/details/127325735

3. In stage 2, freeze the encoder and fine-tune the decoder:

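```python
# Use parameter names to control which parts are updated in each training stage
# (c is the project's config module)
if stage == 0:
    # Stage 1: train the shared INN; the decoupled decoder is unused
    lr = c.lr
    for name, para in net.named_parameters():
        if 'inbs_couple' in name:
            para.requires_grad = True
        elif 'inbs_decouple' in name:
            para.requires_grad = False
elif stage == 1:
    # Stage 2: freeze the encoder, fine-tune the decoder
    lr = c.lr
    for name, para in net.named_parameters():
        if 'inbs_couple' in name:
            para.requires_grad = False
        elif 'inbs_decouple' in name:
            para.requires_grad = True
```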
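A minimal sketch that checks the freezing actually holds during an optimizer step, assuming a hypothetical `TwoStageNet` whose submodules use the same names (`nn.Linear` layers stand in for Hinet) and a hypothetical `set_stage` helper that mirrors the name-based logic above. Only trainable parameters are handed to the optimizer here; that is one design choice, not necessarily what the original project does. Note that `requires_grad = False` does not stop BatchNorm running statistics from updating in train mode; if the encoder had such layers, calling `.eval()` on it would also be needed.

```python
import torch
import torch.nn as nn


class TwoStageNet(nn.Module):
    """Hypothetical stand-in with the same submodule names as MyHiNet."""

    def __init__(self):
        super().__init__()
        self.inbs_couple = nn.Linear(4, 4)    # encoder
        self.inbs_decouple = nn.Linear(4, 4)  # stage-2 decoder

    def forward(self, x):
        return self.inbs_decouple(self.inbs_couple(x))


def set_stage(net: nn.Module, stage: int) -> None:
    """stage 0: train inbs_couple only; stage 1: train inbs_decouple only."""
    for name, p in net.named_parameters():
        if "inbs_couple" in name:
            p.requires_grad = (stage == 0)
        elif "inbs_decouple" in name:
            p.requires_grad = (stage == 1)


torch.manual_seed(0)
net = TwoStageNet()
set_stage(net, stage=1)

# Only trainable parameters are passed to the optimizer
optimizer = torch.optim.Adam((p for p in net.parameters() if p.requires_grad), lr=1e-3)

enc_before = net.inbs_couple.weight.detach().clone()
dec_before = net.inbs_decouple.weight.detach().clone()

loss = net(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(torch.equal(net.inbs_couple.weight, enc_before))    # True: encoder frozen
print(torch.equal(net.inbs_decouple.weight, dec_before))  # False: decoder updated
```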
