复现一个模型的时候,需要在训练的后半阶段将一个INN网络的前向推理过程和逆向推理过程解耦,即:前向过程继续使用第一阶段训练得到的网络(encoder),逆向过程改由另一个结构相同、单独微调的网络(decoder)完成。
为了实现上述的2-stage训练策略,我们需要实现以下三个步骤:
1. 在第二阶段训练时,需要增添一个与stage1的INN相同网络结构的模块,作为decoder(stage1训练的网络作为encoder);
class MyHiNet(nn.Module):
    """Pair of structurally identical Hinet networks for 2-stage training.

    ``inbs_couple`` serves every forward pass, and also the reverse pass
    while the two directions are still coupled. ``inbs_decouple`` serves
    only the reverse pass when ``decouple=True``.
    """

    def __init__(self, in_1=3, in_2=3, dwt_opt=True):
        """Build both sub-networks from the same constructor arguments.

        Args:
            in_1: Passed unchanged to ``Hinet``.
            in_2: Passed unchanged to ``Hinet``.
            dwt_opt: Passed unchanged to ``Hinet``.
        """
        # Zero-argument super(): the previous call named ``PRIS``, which is
        # not this class and fails at construction time.
        super().__init__()
        self.inbs_couple = Hinet(in_1=in_1, in_2=in_2, dwt_opt=dwt_opt)
        self.inbs_decouple = Hinet(in_1=in_1, in_2=in_2, dwt_opt=dwt_opt)

    def forward(self, x, rev=False, decouple=False):
        """Run one pass through the sub-network selected by the flags.

        Args:
            x: Input handed to the selected sub-network.
            rev: When true, run the reverse direction (``rev=True`` is
                forwarded to the sub-network).
            decouple: When true, the reverse direction uses
                ``inbs_decouple`` instead of ``inbs_couple``. It has no
                effect on the forward direction.

        Returns:
            Whatever the selected sub-network returns.
        """
        # The forward direction always goes through the coupled network,
        # regardless of ``decouple``.
        if not rev:
            return self.inbs_couple(x)
        if decouple:
            return self.inbs_decouple(x, rev=True)
        return self.inbs_couple(x, rev=True)
# ---- forward pass ----
output = net(input)

# ---- reverse pass ----
# Stage 0 calls the reverse pass with decouple=False, stage 1 with
# decouple=True. Any other stage value leaves output_rev unassigned.
if stage in (0, 1):
    output_rev = net(input_rev, rev=True, decouple=(stage == 1))
2. 在第二阶段训练开始时,需要实现encoder至decoder的参数迁移,即把encoder(inbs_couple)的参数复制给decoder(inbs_decouple),作为decoder微调的初始值。
# Restore a checkpoint when a path is configured.
if load_path != '':
    net = load(net, load_path)
    if stage == 1:
        # Weight transfer: copy inbs_couple's weights into inbs_decouple.
        # strict=True makes load_state_dict reject any key mismatch.
        src_module = net.module.inbs_couple
        dst_module = net.module.inbs_decouple
        inbs_state_dict = src_module.state_dict()
        dst_module.load_state_dict(inbs_state_dict, strict=True)
    print(f'load from {load_path}')
参数迁移测试:
# ---- Parameter-transfer check ----
def _print_state_dict(header, state_dict):
    """Print a header line, then every entry name with its value."""
    print(header)
    for entry_name, entry_value in state_dict.items():
        print(entry_name, "\t", entry_value)


net = MyHiNet()
print(f'load from {load_path_couple}')
net = load(net, load_path_couple)

# Weight transfer: inbs_couple -> inbs_decouple.
inbs_state_dict = net.module.inbs_couple.state_dict()
net.module.inbs_decouple.load_state_dict(inbs_state_dict, strict=True)

# Weights of inbs_couple (the source of the transfer).
_print_state_dict('inbs_couple:\n', inbs_state_dict)

# Weights of inbs_decouple, read back after the transfer.
inbs_decouple_state_dict = net.module.inbs_decouple.state_dict()
_print_state_dict('inbs_decouple:\n', inbs_decouple_state_dict)
预期结果:迁移后inbs_decouple模块的网络参数与inbs_couple模块一致,即两段打印输出中的参数名和参数值应逐项相同。
学习参考:《Pytorch中模型之间的参数共享》(CSDN博客),链接:https://blog.csdn.net/cyj972628089/article/details/127325735
3. 第二阶段训练,实现encoder锁定,decoder微调
# Select by parameter name which sub-network receives gradient updates in
# each training stage: stage 0 trains inbs_couple and freezes inbs_decouple;
# stage 1 does the opposite. Other stage values change nothing.
if stage in (0, 1):
    lr = c.lr
    couple_trainable = stage == 0
    for param_name, param in net.named_parameters():
        if 'inbs_couple' in param_name:
            param.requires_grad = couple_trainable
        elif 'inbs_decouple' in param_name:
            param.requires_grad = not couple_trainable