pytorch 入门 修改学习率

知识点1、访问optimizer的参数,并修改
知识点2、多个模型下修改参数
知识点3、修改指定epoch下的参数

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.datasets import CIFAR10
from jc_utils import resnet
from torchvision import transforms as tfs
from datetime import datetime

知识点1
optimizer.param_groups 是参数组的列表:如果把模型参数分成了好几组,[0] 就代表第一组(一般只有一组);每个参数组是一个字典,通过 ['lr'] 这个键就可以访问到 optimizer 内部的学习率,从而直接修改它

# Knowledge point 1: optimizer.param_groups is a list of dicts; grab a
# group and assign to its 'lr' key to change the learning rate in place.
net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01, weight_decay=1e-4)
first_group = optimizer.param_groups[0]
print('learning rate:{}'.format(first_group['lr']))
print('weight decay:{}'.format(first_group['weight_decay']))
# The dict is shared with the optimizer, so this mutation takes effect.
first_group['lr'] = 1e-5
print('learning rate:{}'.format(first_group['lr']))

知识点2
如果 optimizer 中有多个参数组,可以用 for 循环遍历 param_groups,把所有组的学习率一次性全部修改

# Knowledge point 2: iterate over every parameter group so the new
# learning rate applies to all of them, not just the first.
for group in optimizer.param_groups:
    group['lr'] = 1e-1

知识点3
修改指定epoch的学习率

def get_acc(output, label):
    """Compute classification accuracy for one batch.

    Args:
        output: (N, C) tensor of per-class scores/logits.
        label: (N,) tensor of integer class indices.

    Returns:
        float in [0, 1]: fraction of samples whose argmax prediction
        matches the label.
    """
    total = output.shape[0]
    # argmax over the class dimension -> predicted class per sample
    _, pred_label = output.max(1)
    # .item() alone extracts the Python number; the old `.data.item()`
    # used the deprecated `.data` attribute and is unnecessary here.
    num_correct = (pred_label == label).sum().item()
    return num_correct / total



def set_learning_rate(optimizer, lr):
    """Overwrite the learning rate of every parameter group in *optimizer*.

    Args:
        optimizer: a torch.optim optimizer whose ``param_groups`` will be
            mutated in place.
        lr: the new learning rate assigned to each group's ``'lr'`` key.
    """
    for group in optimizer.param_groups:
        group['lr'] = lr

def train_tf(x):
    """Training-time augmentation: resize, random flip/crop, color jitter,
    then convert to a tensor and normalize each channel to roughly [-1, 1].
    """
    pipeline = tfs.Compose([
        tfs.Resize(120),
        tfs.RandomHorizontalFlip(),
        tfs.RandomCrop(96),
        tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
        tfs.ToTensor(),
        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    return pipeline(x)

def test_tf(x):
    """Evaluation-time transform: deterministic resize, tensor conversion,
    and the same per-channel normalization as training (no augmentation).
    """
    pipeline = tfs.Compose([
        tfs.Resize(96),
        tfs.ToTensor(),
        tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    return pipeline(x)

# Data pipeline: CIFAR-10 with augmentation for training, plain resize for
# evaluation. num_workers=4 loads batches in background processes.
train_set = CIFAR10('./data', train=True, transform=train_tf)
train_data = torch.utils.data.DataLoader(
    train_set, batch_size=256, shuffle=True, num_workers=4)
valid_set = CIFAR10('./data', train=False, transform=test_tf)
valid_data = torch.utils.data.DataLoader(
    valid_set, batch_size=256, shuffle=False, num_workers=4)

# Fresh model/optimizer for the full training run (initial lr 0.1).
net = resnet(3, 10)
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()


# Per-epoch average losses, collected for the plot at the end.
train_losses = []
valid_losses = []


# Move the network to the GPU when available; batches are moved likewise below.
if torch.cuda.is_available():
    net = net.cuda()
    print('of course cuda')
prev_time = datetime.now()
for epoch in range(30):
    if epoch == 20:         # Knowledge point 3: change the lr at a chosen epoch
        set_learning_rate(optimizer, 0.01)   # lr: 0.1 -> 0.01 from epoch 20 on
    train_loss = 0
    train_acc = 0
    net = net.train()  # training mode (enables dropout / batch-norm updates)
    for im, label in train_data:
        if torch.cuda.is_available():
            im = Variable(im.cuda())
            label = Variable(label.cuda())
        else:
            im = Variable(im)
            label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        # Standard step: clear old gradients, backprop, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate batch statistics; averaged over len(train_data) below.
        train_loss += loss.data.item()
        train_acc += get_acc(output, label)

    cur_time = datetime.now()
    # NOTE(review): timedelta.seconds ignores whole days of elapsed time --
    # fine for epochs shorter than 24h; total_seconds() would be exact.
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
        valid_loss = 0
        valid_acc = 0
        net = net.eval()  # evaluation mode (dropout off, batch-norm frozen)
        for im, label in valid_data:
            if torch.cuda.is_available():
                im = Variable(im.cuda())
                label = Variable(label.cuda())
            else:
                im = Variable(im)
                label = Variable(label)

            output = net(im)
            loss = criterion(output, label)
            valid_loss += loss.data.item()
            valid_acc += get_acc(output, label)
        epoch_str = (
                "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f,"
                % (epoch, train_loss / len(train_data), train_acc / len(train_data),
                   valid_loss / len(valid_data), valid_acc / len(valid_data))
        )
    prev_time = cur_time
    # NOTE(review): epoch_str and valid_losses.append below assume valid_data
    # is not None -- true as written, but would NameError otherwise.
    print(epoch_str + time_str)
    train_losses.append(train_loss / len(train_data))
    valid_losses.append(valid_loss / len(valid_data))

# Plot the loss curves; the lr drop at epoch 20 should show as a kink.
import matplotlib.pyplot as plt
plt.plot(train_losses, label='train')
plt.plot(valid_losses, label='valid')
plt.xlabel('epoch')
plt.legend(loc='best')
plt.show()

你可能感兴趣的:(习惯养成,Deep,Learning,pytorch)