PyTorch 学习率动态衰减

import torch

# Cross-entropy criterion used to compute the training loss.
loss_func = torch.nn.CrossEntropyLoss()

# Adam over every parameter exposed by `deeplabv3` (defined elsewhere;
# presumably an nn.Module -- confirm), with an initial learning rate of 0.01.
optimizer = torch.optim.Adam(
    params=deeplabv3.parameters(),
    lr=0.01,
)

# Step decay: the learning rate is multiplied by `gamma` once every
# `step_size` calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    step_size=5,
    gamma=0.1,
)


# Run 20 epochs, stepping the LR scheduler once per epoch and reporting the
# learning rate that was actually used for each epoch.
for epoch in range(20):
    # Read the learning rate BEFORE scheduler.step(): after stepping, the
    # optimizer already holds the rate for the *next* epoch, which would make
    # the log line below off by one epoch at every decay boundary.
    current_lr = optimizer.param_groups[0]['lr']

    # Placeholder from the original snippet: the forward pass that produces
    # `pre` and `label` is elided here (presumably model predictions and
    # ground-truth targets -- confirm against the full training code).
    ...
    loss = loss_func(pre, label)

    # Standard update: clear stale gradients, backpropagate, apply the step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # The scheduler must be stepped after optimizer.step() so the first
    # value of the schedule is not skipped.
    scheduler.step()

    print("学习率:{},迭代次数:{}".format(current_lr, epoch))

你可能感兴趣的:(相关代码,大数据)