pytorch:子模型参数冻结 + BN冻结

使用场景:需要完全冻结某部分的 weight 与 BN 层
  加载预训练模型时,如果只将 para.requires_grad = False ,并不能完全冻结模型的参数,因为模型中的 BN 层并不随 loss.backward() 与 optimizer.step() 来更新,而是在模型 forward 的过程中基于动量来更新,因此需要每个 forward 之前冻结 BN 层:
完整的冻结方式如下:


'''
一堆代码
'''

# 冻结BN
def freeze_bn(m):
    classname = ly.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()


'''
一堆代码
'''
freeze_state_dict = torch.load(opt.loadckpt_freeze)
frozen_list = [k for k, _ in freeze_state_dict['state_dict'].items() if k in model_dict]
# 先冻结除了 BN 以外的参数
for param in model.named_parameters():
    if param[0] in frozen_list:		# 需要冻结的参数列表
        param[1].requires_grad = False

# 优化器优化的参数只包含需要梯度更新的参数        
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.lr, betas=(0.9,0.999))

'''
一堆代码
'''

for epoch in range(opt.epoch):
	model.train()
	optimizer.zero_grad()
	# 冻结BN
	model.apply(freeze_bn)
	# 前向传播
	output = model(input)
	loss = loss_F(gt, output)
	loss.backward()
	optimizer.step()

你可能感兴趣的:(pytorch踩坑日记,pytorch,深度学习,python)