[Introduction + Code] Using GradualWarmupScheduler for Learning Rate Warmup

I spent a whole evening digging for warmup material and found very little of value: most posts just repost each other, and the rest are so abstract that the code is not actually usable. In the end I got it working, so I am writing up a warmup tutorial for everyone; consider this a starting point.

I. Introducing GradualWarmupScheduler

GradualWarmupScheduler(optimizer, multiplier, total_epoch, after_scheduler)

Parameters:
optimizer: the optimizer being wrapped
multiplier: when multiplier=1.0, the learning rate ramps from 0 up to base_lr; when multiplier is greater than 1.0, it ramps from base_lr up to base_lr * multiplier (see the sketch after this list). multiplier must not be less than 1.0.
[So what is base_lr?]
[It is simply the lr you pass to the optimizer. For example, with optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay), base_lr is learning_rate.]
total_epoch: the target learning rate is reached after total_epoch epochs, i.e. how many epochs the warmup lasts
after_scheduler: the learning rate policy to switch to once total_epoch epochs have passed.
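
To make the multiplier semantics concrete, here is a minimal sketch, assuming the pip-installable warmup_scheduler package; the dummy Linear model and the epoch counts are mine, for illustration only:

from torch import nn, optim
from warmup_scheduler import GradualWarmupScheduler

model = nn.Linear(10, 2)                            # dummy model, illustration only
optimizer = optim.SGD(model.parameters(), lr=0.01)  # base_lr = 0.01
warmup = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=5,
                                after_scheduler=None)

for epoch in range(8):
    optimizer.step()   # a real loop would do forward/backward before this
    warmup.step()      # one scheduler step per epoch
    print(epoch, optimizer.param_groups[0]['lr'])
# the lr climbs from 0.01 toward 0.1 over the first 5 epochs, then holds at 0.1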

If you want more detail, such as how the per-epoch step size is computed, see the source: pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py
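
For reference, my reading of that file is that the warmup learning rate is a linear interpolation per epoch; a sketch of the rule (the function name warmup_lr is mine, and the source file above is authoritative):

def warmup_lr(base_lr, multiplier, total_epoch, epoch):
    # Sketch of the warmup rule as I read it in scheduler.py.
    if multiplier == 1.0:
        # ramp linearly from 0 up to base_lr
        return base_lr * epoch / total_epoch
    # ramp linearly from base_lr up to base_lr * multiplier
    return base_lr * ((multiplier - 1.0) * epoch / total_epoch + 1.0)

# e.g. base_lr=0.01, multiplier=10, total_epoch=10:
# epoch 0 -> 0.01, epoch 5 -> 0.055, epoch 10 -> 0.1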

II. How to Use GradualWarmupScheduler

Using my own code as the example, here is a quick walkthrough of how to use GradualWarmupScheduler.

  1. If you do not have the warmup_scheduler package yet, install it: pip install warmup_scheduler

  2. I first create the optimizer, then scheduler_r (the policy I define here: if test accuracy [mode='max'] fails to improve for three consecutive epochs [patience=3], the learning rate drops to 0.1 times its current value [factor=0.1]), and finally scheduler.

  3. scheduler means: over 10 epochs [total_epoch=10] of warmup, the learning rate climbs from 0.01 (base_lr) to 0.1 [multiplier=10]; from epoch 11 onward, it decays according to scheduler_r [after_scheduler=scheduler_r], i.e. the policy described in step 2.

  4. Note that scheduler.step(metrics=test_acc) is called once per epoch, and since the follow-up policy is ReduceLROnPlateau, it must be passed metrics=test_acc.

import time

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from warmup_scheduler import GradualWarmupScheduler

def train(net, device, epochs, learning_rate, weight_decay):
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay)
    # if test accuracy does not improve for 3 epochs, multiply lr by 0.1
    scheduler_r = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3, verbose=True, eps=1e-5)
    # warm up from base_lr (0.01) to base_lr * 10 (0.1) over 10 epochs, then hand off to scheduler_r
    scheduler = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=10, after_scheduler=scheduler_r)
    #scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=1e-5, T_max=5)
    initepoch = 0
    loss = nn.CrossEntropyLoss()
    best_test_acc = 0
    best_epoch = 0

    for epoch in range(initepoch, epochs):  # loop over the dataset multiple times

        net.train()

        timestart = time.time()

        running_loss = 0.0
        total = 0
        correct = 0
        #print(optimizer.param_groups[0]['lr'])
        for i, data in tqdm(enumerate(train_iter)):
            # get the inputs
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)
            l = loss(outputs, labels)
            l.backward()
            optimizer.step()
            running_loss += l.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            train_acc = 100.0 * correct / total

        print('epoch %d, loss: %.4f, train Acc: %.3f%%, time: %.3f sec, lr: %.7f'
              % (epoch + 1, running_loss, train_acc, time.time() - timestart, optimizer.param_groups[0]['lr']))
        print(scheduler.last_epoch)  # the scheduler's internal epoch counter

        # test

        net.eval()

        total = 0
        correct = 0
        with torch.no_grad():
            for data in tqdm(test_iter):
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                # print(outputs.shape)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            test_acc = 100.0 * correct / total
            print('test Acc: %.3f%%' % (test_acc))
            # if epoch > 30:
            #     torch.save(net.state_dict(), '/root/Desktop/cifar-100/checkpoint_512_512_100/' + str(test_acc) + '_Resnet18.pth')
            if test_acc > best_test_acc:
                print('find best! save at checkpoint/cnn_best.pth')
                best_test_acc = test_acc
                best_epoch = epoch
                torch.save(net.state_dict(),
                           '/root/Desktop/exps/cifar100/model_512_100_IF10/best_' + str(best_test_acc) + '_' + str(train_acc) + '_resnet34.pth')

        scheduler.step(metrics=test_acc)  # once per epoch; ReduceLROnPlateau needs the metric


    print('Finished Training')
    print('best test acc epoch: %d' % (best_epoch + 1))
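
For completeness, a hypothetical invocation of this function; the hyperparameter values are placeholders, not tuned, and train_iter / test_iter must already exist as DataLoaders, as the code above assumes:

import torchvision

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = torchvision.models.resnet34(num_classes=100).to(device)  # CIFAR-100, as in the save path
train(net, device, epochs=60, learning_rate=0.01, weight_decay=5e-4)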
