Pytorch冻结预训练权重(特征提取与BN层)

1. 读取预训练权重

pre_weights = torch.load(model_weights_path, map_location=device)

2. 读取预训练权重中与现有模型参数设置相同层的权重,可适用于修改了分类或某些层通道数的情况

net = yourmodel()
pre_dict = {k: v for k, v in pre_weights.items() 
    if net.state_dict()[k].numel() == v.numel()}
# strict = False 表示仅读取可以匹配的权重
missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict = False)

3. 冻结特征提取层预训练权重

for params in net.features.parameters():
    params.requires_grad = False

4. 由于BN层参数是由各通道值计算得出,在forward中自动实现,而不是通过梯度计算和反向传播更新,需额外冻结BN层权重

def freeze_bn(ly):
    classname = ly.__class__.__name__
    if classname.find('BatchNorm') != -1:
        ly.eval()
net.apply(freeze_bn)

5. 相关链接

1. pytorch中的BN层简介_lpj822的专栏-CSDN博客

2. model.load_state_dict(state_dict, strict=False)_t20134297的博客-CSDN博客

你可能感兴趣的:(Pytorch,pytorch,迁移学习)