Pytorch冻结和解冻结预训练网络的finetune方法

目前Transformer在CV届已经大杀四方,各个赛事上都取得了SOAT的水平。很多人也想着将各种Transformer-based的backbone拿来用,但是很多backbone都是需要加载预训练参数才能使用,所以我们需要将网络进行迁移学习,才能拿来使用。我们通常会在网络刚开始训练的时候冻结除分类层之外的参数,在训练一到两轮再解封参数。下面我以swin-Transformer为例,介绍如何进行网络的fine-tune。

一、冻结网络参数

       首先需要使用Timm库加载Swin的结构以及参数,Timm库是Ross Wightman大神在github上开源的用于图像分类的库,包含各种各样的CNN以及Transformer-based的backbone,以及预训练参数模型,使用起来非常友好,真顶!例如加载Swin只需要一句话就好了。

from timm.models import create_model

Swin=create_model('swin_large_patch4_window7_224_in22k',pretrained=True)

因为我的分类任务数为14和预训练参数中的不匹配,并且我想要在训练前期固定除了分类层之外的所有参数,所以我加载网络后,会先去掉分类层,然后固定这部分的参数,接着再重新构建分类层。固定参数的代码如下:

for p in self.backbone.parameters():
    p.requires_grad = False

整体代码如下:

class classifer(nn.Module):
	def __init__(self,in_ch,num_classes):
		super().__init__()
		self.avgpool = nn.AdaptiveAvgPool1d(1)
		self.fc = nn.Linear(in_ch,num_classes)

	def forward(self, x):
		x = self.avgpool(x.transpose(1, 2))  # B C 1
		x = torch.flatten(x, 1)
		x = self.fc(x)
		return x

class Swin(nn.Module):
    def __init__(self):
        super().__init__() 
        #创建模型,并且加载预训练参数
        self.swin= create_model('swin_large_patch4_window7_224_in22k',pretrained=True)
        #整体模型的结构
        pretrained_dict = self.swin.state_dict()
        #去除模型的分类层
        self.backbone = nn.Sequential(*list(self.swin.children())[:-2])
        #去除分类层的模型架构
        model_dict = self.backbone.state_dict()

        # 将pretrained_dict里不属于model_dict的键剔除掉
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # 更新现有的model_dict
        model_dict.update(pretrained_dict)
        # 加载我们真正需要的state_dict
        self.backbone.load_state_dict(model_dict)
        #屏蔽除分类层所有的参数
        for p in self.backbone.parameters():
            p.requires_grad = False
        #构建新的分类层
        self.head = classifer(1536, 14)

    def forward(self, x):
        x = self.backbone(x)
        x=self.head(x)
        return x

除了在模型中屏蔽参数外,还要再优化器中进行屏蔽,需要使用filter进行过滤,就是只优化梯度为True的参数,即分类层

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, cnn.parameters()), lr=1e-4, betas=(0.9, 0.999),weight_decay=1e-6)

二、解冻网络参数

在固定参数训练一轮以后,再解冻backbone部分的参数

if epoch ==1:
    for p in  Swin.backbone.parameters():
        p.requires_grad = True
        optimizer.add_param_group({'params': Swin.backbone.parameters()})

你可能感兴趣的:(pytorch,transformer,深度学习)