The role of the optimizer: manage and update the values of a model's learnable parameters so that the model's output gets closer to the ground-truth labels.
Manage: decide which parameters get updated.
Update: change the parameter values according to a given optimization strategy.
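A minimal sketch of how these two roles show up in a single training step (the model, data, and loss function here are placeholders, not part of the original notes):

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(10, 2)                            # placeholder model
criterion = nn.CrossEntropyLoss()                   # placeholder loss function
optimizer = optim.SGD(model.parameters(), lr=0.1)   # "manage": the optimizer holds references to model.parameters()

inputs = torch.randn(4, 10)                         # dummy batch
labels = torch.tensor([0, 1, 0, 1])

optimizer.zero_grad()                               # clear gradients left over from the previous step
loss = criterion(model(inputs), labels)
loss.backward()                                     # compute gradients of the loss w.r.t. the managed parameters
optimizer.step()                                    # "update": apply the SGD rule to every managed parameter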
To guard against unexpected interruptions, the training state of the network should be saved every fixed number of epochs, so that training can resume from the saved state after an interruption.
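Continuing the sketch above, a hedged version of that checkpointing pattern (the interval, file names, and max_epoch are illustrative assumptions, not from the original notes):

checkpoint_interval = 5
max_epoch = 10
for epoch in range(max_epoch):
    # ... one epoch of training with optimizer.zero_grad() / loss.backward() / optimizer.step() ...
    if (epoch + 1) % checkpoint_interval == 0:
        checkpoint = {"model_state_dict": model.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        torch.save(checkpoint, "checkpoint_{}_epoch.pkl".format(epoch))

# to resume after an interruption:
# checkpoint = torch.load("checkpoint_4_epoch.pkl")
# model.load_state_dict(checkpoint["model_state_dict"])
# optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
# start_epoch = checkpoint["epoch"] + 1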
The result of adding a parameter group can be seen in the output of the demo below.
The demo is constructed as follows:
import os
import torch
import torch.optim as optim

# directory used to save/load the optimizer state (added here so the snippet is self-contained)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))

optimizer = optim.SGD([weight], lr=0.1)

# ----------------------------------- step -----------------------------------
flag = 0
# flag = 1
if flag:
    print("weight before step:{}".format(weight.data))
    optimizer.step()        # try lr=1 and lr=0.1 and compare the results
    print("weight after step:{}".format(weight.data))

# ----------------------------------- zero_grad -----------------------------------
# flag = 0
flag = 1
if flag:
    print("weight before step:{}".format(weight.data))
    optimizer.step()        # try lr=1 and lr=0.1 and compare the results
    print("weight after step:{}".format(weight.data))

    # the optimizer stores a reference to the same tensor object, not a copy
    print("weight in optimizer:{}\nweight in weight:{}\n".format(id(optimizer.param_groups[0]['params'][0]), id(weight)))

    print("weight.grad is {}\n".format(weight.grad))
    optimizer.zero_grad()
    print("after optimizer.zero_grad(), weight.grad is\n{}".format(weight.grad))

# ----------------------------------- add_param_group -----------------------------------
# flag = 0
flag = 1
if flag:
    print("optimizer.param_groups is\n{}".format(optimizer.param_groups))

    w2 = torch.randn((3, 3), requires_grad=True)
    optimizer.add_param_group({"params": w2, 'lr': 0.0001})

    print("optimizer.param_groups is\n{}".format(optimizer.param_groups))

# ----------------------------------- state_dict -----------------------------------
flag = 0
# flag = 1
if flag:
    optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
    opt_state_dict = optimizer.state_dict()

    print("state_dict before step:\n", opt_state_dict)

    for i in range(10):
        optimizer.step()

    print("state_dict after step:\n", optimizer.state_dict())

    torch.save(optimizer.state_dict(), os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))

# ----------------------------------- load state_dict -----------------------------------
# flag = 0
flag = 1
if flag:
    optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
    state_dict = torch.load(os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))

    print("state_dict before load state:\n", optimizer.state_dict())
    optimizer.load_state_dict(state_dict)
    print("state_dict after load state:\n", optimizer.state_dict())
Suppose the loss function is y = 4*w^2 (where w is the network weight), as plotted by the demo code below.
The weight update formula of gradient descent is w_{i+1} = w_i - LR * dy/dw_i, where LR is the learning rate.
When the learning rate is 0.5, the update step is too large and the loss explodes.
After changing the learning rate to 0.2, the loss decreases steadily, so the choice of learning rate matters a great deal.
However, for different loss functions and networks, how to pick the learning rate is an open question; a common strategy is to use a relatively small learning rate and trade training time for accuracy.
The multi-learning-rate experiment below shows that with a learning rate of 0.01 the loss decreases slowly, but it still reaches a good result in the end.
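A short calculation makes this concrete: since dy/dw = 8*w, one update gives w_{i+1} = w_i - LR*8*w_i = (1 - 8*LR)*w_i. With LR = 0.5 the factor is 1 - 4 = -3, so |w| triples at every step and the loss explodes; with LR = 0.2 the factor is -0.6, so the loss shrinks steadily; LR = 0.125 makes the factor exactly 0 and reaches the minimum in a single step.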
# -*- coding:utf-8 -*-
"""
@file name  : learning_rate.py
@author     : TingsongYu https://github.com/TingsongYu
@date       : 2019-10-16
@brief      : demonstration of how the learning rate affects gradient descent
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1)


def func(x_t):
    """
    y = (2x)^2 = 4*x^2    dy/dx = 8x
    """
    return torch.pow(2*x_t, 2)


# init
x = torch.tensor([2.], requires_grad=True)

# ------------------------------ plot data ------------------------------
flag = 0
# flag = 1
if flag:
    x_t = torch.linspace(-3, 3, 100)
    y = func(x_t)
    plt.plot(x_t.numpy(), y.numpy(), label="y = 4*x^2")
    plt.grid()
    plt.xlabel("x")
    plt.ylabel("y")
    plt.legend()
    plt.show()

# ------------------------------ gradient descent ------------------------------
flag = 0
# flag = 1
if flag:
    iter_rec, loss_rec, x_rec = list(), list(), list()

    lr = 0.2             # try 1. / 0.5 / 0.2 / 0.1 / 0.125
    max_iteration = 20   # e.g. 4 iterations for lr=1. or 0.5; 20 or 200 for lr=0.2

    for i in range(max_iteration):

        y = func(x)
        y.backward()

        print("Iter:{}, X:{:8}, X.grad:{:8}, loss:{:10}".format(
            i, x.detach().numpy()[0], x.grad.detach().numpy()[0], y.item()))

        x_rec.append(x.item())

        x.data.sub_(lr * x.grad)    # the update rule x = x - lr * dy/dx
        x.grad.zero_()

        iter_rec.append(i)
        loss_rec.append(y.item())   # store a plain float so matplotlib can plot it

    plt.subplot(121).plot(iter_rec, loss_rec, '-ro')
    plt.xlabel("Iteration")
    plt.ylabel("Loss value")

    x_t = torch.linspace(-3, 3, 100)
    y = func(x_t)
    plt.subplot(122).plot(x_t.numpy(), y.numpy(), label="y = 4*x^2")
    plt.grid()
    y_rec = [func(torch.tensor(i)).item() for i in x_rec]
    plt.subplot(122).plot(x_rec, y_rec, '-ro')
    plt.legend()
    plt.show()

# ------------------------------ multiple learning rates ------------------------------
flag = 0
# flag = 1
if flag:
    iteration = 100
    num_lr = 10
    lr_min, lr_max = 0.01, 0.2  # try lr_max = 0.5 / 0.3 / 0.2

    lr_list = np.linspace(lr_min, lr_max, num=num_lr).tolist()
    loss_rec = [[] for _ in range(len(lr_list))]
    iter_rec = list()

    for i, lr in enumerate(lr_list):
        x = torch.tensor([2.], requires_grad=True)
        for iter in range(iteration):

            y = func(x)
            y.backward()
            x.data.sub_(lr * x.grad)    # x.data -= lr * x.grad
            x.grad.zero_()

            loss_rec[i].append(y.item())

    for i, loss_r in enumerate(loss_rec):
        plt.plot(range(len(loss_r)), loss_r, label="LR: {}".format(lr_list[i]))
    plt.legend()
    plt.xlabel('Iterations')
    plt.ylabel('Loss value')
    plt.show()
Exponentially weighted average
The exponentially weighted average is defined by v_t = beta * v_{t-1} + (1 - beta) * theta_t, where theta_t is the value at step t. Unrolling the recursion, the weight given to the value i steps in the past is (1 - beta) * beta^i: the farther a point is from the current step, the smaller its weight, decaying exponentially.
The parameter beta controls how fast the weights decay: the smaller beta is, the shorter the memory of past data.
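A small sketch (not from the original notes) that prints these weights for two values of beta, showing the shorter memory of the smaller beta:

import numpy as np

def ewa_weights(beta, num_steps=10):
    # weight given to the value i steps in the past: (1 - beta) * beta**i
    return np.array([(1 - beta) * beta ** i for i in range(num_steps)])

print(ewa_weights(0.9)[:5])   # slow decay: long memory of past data
print(ewa_weights(0.5)[:5])   # fast decay: only the most recent steps matter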
The update rule used by SGD with momentum in PyTorch is v_t = momentum * v_{t-1} + g_t, followed by w_t = w_{t-1} - lr * v_t, where g_t is the current gradient and v_t is the momentum buffer (velocity). Note that, unlike the exponentially weighted average above, the gradient g_t is not scaled by (1 - momentum).
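A minimal check of this rule against torch.optim.SGD (the constant gradient value is arbitrary):

import torch
import torch.optim as optim

lr, momentum = 0.1, 0.9
w = torch.tensor([1.0], requires_grad=True)
optimizer = optim.SGD([w], lr=lr, momentum=momentum)

v = torch.zeros(1)                    # momentum buffer tracked by hand
for _ in range(3):
    w.grad = torch.tensor([2.0])      # pretend the gradient is always 2
    v = momentum * v + w.grad         # v_t = momentum * v_{t-1} + g_t
    expected = w.detach() - lr * v    # w_t = w_{t-1} - lr * v_t
    optimizer.step()
    print(w.detach(), expected)       # the two values should match (up to floating-point precision)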