Fine-tuning tricks (linear probing and full fine-tuning)

Application scenario: image classification (CIFAR-10 as the benchmark dataset, ResNet18 as the backbone network)

Data preprocessing (resize_val = 224, n_holes = 1; length = 16 at the original 32x32 size)

# Image preprocessing (transforms comes from torchvision.transforms;
# Cutout is the augmentation class sketched below)
normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

train_transform = transforms.Compose([])
if args.resize_val != -1:
    # Upscale the 32x32 CIFAR-10 images to the pretrained backbone's input size
    train_transform.transforms.append(transforms.Resize((args.resize_val, args.resize_val)))
    crop_size = args.resize_val
else:
    crop_size = 32  # native CIFAR-10 resolution (avoids RandomCrop(-1) when no resize is requested)

train_transform.transforms.append(transforms.RandomCrop(crop_size, padding=4))
train_transform.transforms.append(transforms.RandomHorizontalFlip())
train_transform.transforms.append(transforms.ToTensor())
train_transform.transforms.append(normalize)
train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))


# Test-time preprocessing: deterministic resize only, no augmentation
# (assumes resize_val != -1)
test_transform = transforms.Compose([
    transforms.Resize((args.resize_val, args.resize_val)),
    transforms.ToTensor(),
    normalize])
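
The Cutout class used in the training transform is not defined in this snippet. Below is a minimal sketch following the common reference implementation from the Cutout paper (DeVries & Taylor, 2017); the exact code here is my reconstruction, not the original post's:

import numpy as np
import torch

class Cutout(object):
    """Randomly mask out n_holes square patches of side `length` from an image."""
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        # img: float tensor of shape (C, H, W), produced by ToTensor/Normalize
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        for _ in range(self.n_holes):
            y, x = np.random.randint(h), np.random.randint(w)
            y1, y2 = np.clip([y - self.length // 2, y + self.length // 2], 0, h)
            x1, x2 = np.clip([x - self.length // 2, x + self.length // 2], 0, w)
            mask[y1:y2, x1:x2] = 0.0
        # Broadcast the (H, W) mask over all channels and zero out the patches
        return img * torch.from_numpy(mask).expand_as(img)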

Training optimizer and LR decay schedule (lr = 0.01, wd = 5e-4)

# Loss, optimizer, and scheduler; MultiStepLR comes from torch.optim.lr_scheduler
criterion = nn.CrossEntropyLoss()
cnn_optimizer = torch.optim.SGD(cnn.parameters(), lr=args.learning_rate,
                                momentum=0.9, nesterov=True, weight_decay=args.wd)

# Multiply the LR by 0.2 at epochs 60, 120, and 160
scheduler = MultiStepLR(cnn_optimizer, milestones=[60, 120, 160], gamma=0.2)
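
For reference, a minimal epoch loop under these settings. The 200-epoch count and train_loader are my assumptions (the milestones suggest a roughly 200-epoch schedule), not stated in the original:

for epoch in range(200):  # assumed total; milestones at 60/120/160 imply ~200 epochs
    cnn.train()
    for images, labels in train_loader:  # train_loader: assumed CIFAR-10 DataLoader
        images, labels = images.cuda(), labels.cuda()
        cnn_optimizer.zero_grad()
        loss = criterion(cnn(images), labels)
        loss.backward()
        cnn_optimizer.step()
    scheduler.step()  # step once per epoch so the milestone decays fire correctly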

(1) Training from scratch: simply train with the code above as-is. A minimal sketch of building the backbone for this case, assuming torchvision's ResNet18 (the import and .cuda() call are my additions, not the original post's), is shown below.
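
    from torchvision.models import resnet18

    cnn = resnet18(num_classes=args.num_class).cuda()  # random init, no pretrained weights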

(2) Full fine-tuning (finetune all: no layers frozen)

    # The model must still have its default 1000-class head at this point,
    # so the ImageNet-1k checkpoint loads cleanly; the head is replaced after
    cnn.load_state_dict(torch.load(path))
    inchannel = cnn.fc.in_features
    cnn.fc = nn.Linear(inchannel, args.num_class)

Notes:

a) When calling load_state_dict, the model should still use the default 1000-class head (since I am using ImageNet-1k pretrained weights); only after loading is the head replaced with the specified args.num_class.

b) path is the location of the pretrained weights; you can use the checkpoints torchvision provides (pasted here for convenience):

ResNet18: https://download.pytorch.org/models/resnet18-f37072fd.pth

ResNet34: https://download.pytorch.org/models/resnet34-b627a593.pth

ResNet50: https://download.pytorch.org/models/resnet50-0676ba61.pth
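
Putting notes a) and b) together, a sketch of the full loading sequence; using torch.hub.load_state_dict_from_url to fetch one of the URLs above is my choice of convenience, not necessarily the original workflow:

    from torch.hub import load_state_dict_from_url
    from torchvision.models import resnet18

    cnn = resnet18(num_classes=1000)  # head must match the 1000-class checkpoint
    state_dict = load_state_dict_from_url(
        'https://download.pytorch.org/models/resnet18-f37072fd.pth')  # cached locally
    cnn.load_state_dict(state_dict)
    cnn.fc = nn.Linear(cnn.fc.in_features, args.num_class)  # now swap in the new head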

(3) Linear probing (freeze everything except the final linear classifier)

    cnn.load_state_dict(torch.load(path))
    inchannel = cnn.fc.in_features
    cnn.fc = nn.Linear(inchannel, args.num_class)
    # Freeze every parameter except the newly created classifier head;
    # the fresh fc layer keeps requires_grad=True by default
    for k, v in cnn.named_parameters():
        if k.split('.')[0] != 'fc':
            v.requires_grad = False
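
One optional refinement (my suggestion, not in the original): since the optimizer above was built over cnn.parameters(), you can rebuild it over only the trainable parameters. Plain SGD already skips parameters whose .grad stays None, so this mainly makes the intent explicit:

    cnn_optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, cnn.parameters()),
        lr=args.learning_rate, momentum=0.9, nesterov=True, weight_decay=args.wd)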

However, fine-tuning only the linear layer performs far worse than full fine-tuning here; I will add an analysis of the cause once I finish investigating it over the next few days.
