I spent a whole evening hunting for warmup material and found very little that was useful: most posts just repost one another, or stay so vague that the code has no practical value. I did get it working in the end, so here is a warmup tutorial for anyone else in the same boat; treat it as a starting point rather than the last word.
GradualWarmupScheduler(optimizer, multiplier, total_epoch, after_scheduler)
Parameters:
optimizer: the optimizer to wrap.
multiplier: when multiplier=1.0, the learning rate lr ramps up from 0 to base_lr; when multiplier is greater than 1.0, lr ramps up from base_lr to base_lr * multiplier. multiplier must not be less than 1.0.
[So what is base_lr?]
[It is simply the lr passed to the optimizer. For example, with optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay), base_lr is learning_rate.]
total_epoch: the number of epochs after which the target learning rate is reached, i.e. how many epochs the warmup lasts.
after_scheduler: the learning rate schedule to switch to once total_epoch epochs have passed.
For more detail, such as how the learning rate is stepped each epoch, see the source implementation: pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py
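Before moving on to my training code, here is a minimal sketch you can run to watch the schedule in action. The toy nn.Linear model, multiplier=5, total_epoch=5 and the CosineAnnealingLR after_scheduler are placeholder choices of mine for illustration, not values from my experiment; from reading scheduler.py, the warmup lr appears to follow roughly base_lr * ((multiplier - 1) * last_epoch / total_epoch + 1) when multiplier > 1.

import torch
from torch import nn, optim
from warmup_scheduler import GradualWarmupScheduler

model = nn.Linear(10, 2)  # toy model, only needed so the optimizer has parameters to manage
optimizer = optim.SGD(model.parameters(), lr=0.01)  # base_lr = 0.01
cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)  # placeholder after_scheduler
# warm up from 0.01 to 0.01 * 5 over 5 epochs, then hand control to cosine annealing
scheduler = GradualWarmupScheduler(optimizer, multiplier=5, total_epoch=5, after_scheduler=cosine)

for epoch in range(25):
    optimizer.step()  # in real training this would be the per-batch forward/backward/step loop
    scheduler.step()  # one scheduler step per epoch
    print(epoch + 1, optimizer.param_groups[0]['lr'])  # prints the lr ramping up, then decaying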
Below I use my own code as an example to walk through how GradualWarmupScheduler is used in practice.
If you do not have the warmup_scheduler package yet, install it first: pip install warmup_scheduler
I first create the optimizer, then schedular_r (the policy I define here: if the test accuracy [mode='max'] fails to improve for 3 consecutive epochs [patience=3], the learning rate is multiplied by 0.1 [factor=0.1]), and finally schedular itself.
schedular means: over 10 epochs of warmup [total_epoch=10], the learning rate climbs gradually from 0.01 (base_lr) to 0.1 [multiplier=10]; from epoch 11 onward the learning rate decays according to schedular_r [after_scheduler=schedular_r], exactly as described in the parameter explanation above.
Note that schedular.step(metrics=test_acc) is called once per epoch, and because the follow-up policy here is ReduceLROnPlateau, it must be given the metrics=test_acc argument.
import time

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from warmup_scheduler import GradualWarmupScheduler


# train_iter / test_iter are DataLoaders assumed to be defined elsewhere in the script
def train(net, device, epochs, learning_rate, weight_decay):
    optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay)
    # if test accuracy does not improve for 3 epochs, multiply lr by 0.1
    schedular_r = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3, verbose=True, eps=1e-5)
    # warm up from base_lr to base_lr * 10 over 10 epochs, then follow schedular_r
    schedular = GradualWarmupScheduler(optimizer, multiplier=10, total_epoch=10, after_scheduler=schedular_r)
    # schedular = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=1e-5, T_max=5)
    initepoch = 0
    loss = nn.CrossEntropyLoss()
    best_test_acc = 0
    best_epoch = 0
    for epoch in range(initepoch, epochs):  # loop over the dataset multiple times
        net.train()
        timestart = time.time()
        running_loss = 0.0
        total = 0
        correct = 0
        # print(optimizer.param_groups[0]['lr'])
        for i, data in tqdm(enumerate(train_iter)):
            # get the inputs
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)
            l = loss(outputs, labels)
            l.backward()
            optimizer.step()
            running_loss += l.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        train_acc = 100.0 * correct / total
        print('epoch %d, loss: %.4f, train Acc: %.3f%%, time: %.3f sec, lr: %.7f'
              % (epoch + 1, running_loss, train_acc, time.time() - timestart, optimizer.param_groups[0]['lr']))
        print(schedular.last_epoch)

        # test
        net.eval()
        total = 0
        correct = 0
        with torch.no_grad():
            for data in tqdm(test_iter):
                images, labels = data
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                # print(outputs.shape)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_acc = 100.0 * correct / total
        print('test Acc: %.3f%%' % (test_acc))
        # if epoch > 30:
        #     torch.save(net.state_dict(), '/root/Desktop/cifar-100/checkpoint_512_512_100/' + str(test_acc) + '_Resnet18.pth')
        if test_acc > best_test_acc:
            print('find best! save at checkpoint/cnn_best.pth')
            best_test_acc = test_acc
            best_epoch = epoch
            torch.save(net.state_dict(),
                       '/root/Desktop/exps/cifar100/model_512_100_IF10/best_' + str(best_test_acc) + '_' + str(train_acc) + '_resnet34.pth')
        # one scheduler step per epoch; ReduceLROnPlateau as after_scheduler needs the metric
        schedular.step(metrics=test_acc)
    print('Finished Training')
    print('best test acc epoch: %d' % (best_epoch + 1))
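For reference, a call to train() might look like the sketch below; net, train_iter and test_iter are assumed to be built elsewhere (e.g. a ResNet and CIFAR-100 DataLoaders), and the hyperparameter values are placeholders rather than the ones I actually tuned.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = net.to(device)  # 'net' is assumed to be constructed already, e.g. a ResNet-34
train(net, device, epochs=100, learning_rate=0.01, weight_decay=5e-4)  # placeholder hyperparameters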