pytorch学习笔记:优化器

1、优化器的概念

优化器的作用:管理更新模型中可学习参数的值,使得模型输出更接近真实标签。

管理:更新哪些参数

更新:根据一定的优化策略更新参数的值

pytorch学习笔记:优化器_第1张图片

 2、基本属性

pytorch学习笔记:优化器_第2张图片

pytorch学习笔记:优化器_第3张图片

pytorch学习笔记:优化器_第4张图片

pytorch学习笔记:优化器_第5张图片

pytorch学习笔记:优化器_第6张图片

为了避免一些意外情况的发生,每隔一定的epoch就保存一次网络训练的状态信息,从而可以在意外中断后继续训练。

2.1、单步调试代码观察优化器建立过程

  • 首先运行到断点出step into

  • 2.进行到SGD初始化函数

pytorch学习笔记:优化器_第7张图片

  • 3.运行到64行step into进入父类optimizer继续初始化

pytorch学习笔记:优化器_第8张图片

pytorch学习笔记:优化器_第9张图片

pytorch学习笔记:优化器_第10张图片

 添加参数后结果如下

pytorch学习笔记:优化器_第11张图片

网络的构建过程

pytorch学习笔记:优化器_第12张图片

pytorch学习笔记:优化器_第13张图片

pytorch学习笔记:优化器_第14张图片

  • 4.step out跳出并完成优化器的创建

pytorch学习笔记:优化器_第15张图片

pytorch学习笔记:优化器_第16张图片

  • 5.清空梯度
  • 6.更新参数

pytorch学习笔记:优化器_第17张图片

2.2、优化器基本方法的使用

pytorch学习笔记:优化器_第18张图片

  • step():一步更新
weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))

optimizer = optim.SGD([weight], lr=0.1)

# ----------------------------------- step -----------------------------------
flag = 0
# flag = 1
if flag:
    print("weight before step:{}".format(weight.data))
    optimizer.step()        # 修改lr=1 0.1观察结果
    print("weight after step:{}".format(weight.data))

pytorch学习笔记:优化器_第19张图片

  •  zero_gard():清空梯度
# ------------------------ zero_grad --------------------------------
# flag = 0
flag = 1
if flag:

    print("weight before step:{}".format(weight.data))
    optimizer.step()        # 修改lr=1 0.1观察结果
    print("weight after step:{}".format(weight.data))

    print("weight in optimizer:{}\nweight in weight:{}\n".format(id(optimizer.param_groups[0]['params'][0]), id(weight)))

    print("weight.grad is {}\n".format(weight.grad))
    optimizer.zero_grad()
    print("after optimizer.zero_grad(), weight.grad is\n{}".format(weight.grad))

pytorch学习笔记:优化器_第20张图片

  • add_param_group():添加参数组

pytorch学习笔记:优化器_第21张图片

# ----------------------------------- add_param_group -----------------------------------
# flag = 0
flag = 1
if flag:
    print("optimizer.param_groups is\n{}".format(optimizer.param_groups))

    w2 = torch.randn((3, 3), requires_grad=True)

    optimizer.add_param_group({"params": w2, 'lr': 0.0001})

    print("optimizer.param_groups is\n{}".format(optimizer.param_groups))
  •  state_dict():获取优化器当前的状态字典

pytorch学习笔记:优化器_第22张图片

# ----------------------------------- state_dict -----------------------------------
flag = 0
# flag = 1
if flag:

    optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
    opt_state_dict = optimizer.state_dict()

    print("state_dict before step:\n", opt_state_dict)

    for i in range(10):
        optimizer.step()

    print("state_dict after step:\n", optimizer.state_dict())

    torch.save(optimizer.state_dict(), os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))
  •  load_state_dict()加载保存的状态字典

pytorch学习笔记:优化器_第23张图片

# -----------------------------------load state_dict -----------------------------------
# flag = 0
flag = 1
if flag:

    optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
    state_dict = torch.load(os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))

    print("state_dict before load state:\n", optimizer.state_dict())
    optimizer.load_state_dict(state_dict)
    print("state_dict after load state:\n", optimizer.state_dict())

2.3、学习率对训练的影响

假设损失函数如下图所示为y=4*w^2(w为网络的权重)

pytorch学习笔记:优化器_第24张图片

权重的更新公式为

pytorch学习笔记:优化器_第25张图片

当学习率为0.5时,由于更新步幅太大造成损失的爆炸

pytorch学习笔记:优化器_第26张图片

修改学习率为0.2后,损失函数可以稳步的下降,因此学习率的选择非常重要。

pytorch学习笔记:优化器_第27张图片

然而面对不同的损失函数和网络,如何选择学习率是一个问题,一般的策略为采取较小的学习率,用时间来换取精度

pytorch学习笔记:优化器_第28张图片

从上图可以看出,学习率为0.01时,虽然损失下降的比较慢,但最终也能达到不错的效果。

# -*- coding:utf-8 -*-
"""
@file name  : learning_rate.py
# @author     : TingsongYu https://github.com/TingsongYu
@date       : 2019-10-16
@brief      : 梯度下降的学习率演示
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1)


def func(x_t):
    """
    y = (2x)^2 = 4*x^2      dy/dx = 8x
    """
    return torch.pow(2*x_t, 2)


# init
x = torch.tensor([2.], requires_grad=True)


# ------------------------------ plot data ------------------------------
flag = 0
# flag = 1
if flag:

    x_t = torch.linspace(-3, 3, 100)
    y = func(x_t)
    plt.plot(x_t.numpy(), y.numpy(), label="y = 4*x^2")
    plt.grid()
    plt.xlabel("x")
    plt.ylabel("y")
    plt.legend()
    plt.show()


# ------------------------------ gradient descent ------------------------------
flag = 0
# flag = 1
if flag:
    iter_rec, loss_rec, x_rec = list(), list(), list()

    lr = 0.2    # /1. /.5 /.2 /.1 /.125
    max_iteration = 20   # /1. 4     /.5 4   /.2 20 200

    for i in range(max_iteration):

        y = func(x)
        y.backward()

        print("Iter:{}, X:{:8}, X.grad:{:8}, loss:{:10}".format(
            i, x.detach().numpy()[0], x.grad.detach().numpy()[0], y.item()))

        x_rec.append(x.item())

        x.data.sub_(lr * x.grad)    # x -= x.grad  数学表达式意义:  x = x - x.grad    # 0.5 0.2 0.1 0.125
        x.grad.zero_()

        iter_rec.append(i)
        loss_rec.append(y)

    plt.subplot(121).plot(iter_rec, loss_rec, '-ro')
    plt.xlabel("Iteration")
    plt.ylabel("Loss value")

    x_t = torch.linspace(-3, 3, 100)
    y = func(x_t)
    plt.subplot(122).plot(x_t.numpy(), y.numpy(), label="y = 4*x^2")
    plt.grid()
    y_rec = [func(torch.tensor(i)).item() for i in x_rec]
    plt.subplot(122).plot(x_rec, y_rec, '-ro')
    plt.legend()
    plt.show()

# ------------------------------ multi learning rate ------------------------------

flag = 0
# flag = 1
if flag:
    iteration = 100
    num_lr = 10
    lr_min, lr_max = 0.01, 0.2  # .5 .3 .2

    lr_list = np.linspace(lr_min, lr_max, num=num_lr).tolist()
    loss_rec = [[] for l in range(len(lr_list))]
    iter_rec = list()

    for i, lr in enumerate(lr_list):
        x = torch.tensor([2.], requires_grad=True)
        for iter in range(iteration):

            y = func(x)
            y.backward()
            x.data.sub_(lr * x.grad)  # x.data -= x.grad
            x.grad.zero_()

            loss_rec[i].append(y.item())

    for i, loss_r in enumerate(loss_rec):
        plt.plot(range(len(loss_r)), loss_r, label="LR: {}".format(lr_list[i]))
        plt.legend() #
        plt.xlabel('Iterations')
        plt.ylabel('Loss value')
        plt.show()

2.4、动量

指数加权平均

pytorch学习笔记:优化器_第29张图片

各个时刻的权重如下图所示,离当前时刻越远,权重越小,呈指数衰减

pytorch学习笔记:优化器_第30张图片

参数beta对权重的影响下图所示,可以看出beta的值越小,对过去数据的记忆时间越短

pytorch学习笔记:优化器_第31张图片

pytorch中的带动量SGD更新公式

pytorch学习笔记:优化器_第32张图片

 

3、pytorch优化器

3.1、SGD

pytorch学习笔记:优化器_第33张图片

 3.2、其他优化器

pytorch学习笔记:优化器_第34张图片

 

pytorch学习笔记:优化器_第35张图片

pytorch学习笔记:优化器_第36张图片 

pytorch学习笔记:优化器_第37张图片

 

你可能感兴趣的:(pytorch学习笔记)