使用随机梯度下降SGD的BP反向传播算法的PyTorch代码实现

Index 目录索引

  • 写在前面
  • PyTorch的 .data()
  • PyTorch的 .item()
  • BP with SGD的PyTorch代码实现
  • 参考文章


写在前面

本文将用一个完整的例子,借助PyTorch代码来实现神经网络的BP反向传播算法,为了承接上篇文章,本文中的例子仍然使用到了SGD随机梯度算法1【这是深度学习数学原理专题系列的第二篇文章】


PyTorch的 .data()

首先来看PyTorch官方文档里面对该函数的介绍,翻译过来就是,该方法的功能是以标准的Python数字的形式来返回这个张量的值,这个方法只能用于只包含一个元素的张量,对于其他的张量,请查看方法tolist(),该操作是不可微分的,即不可求导。
使用随机梯度下降SGD的BP反向传播算法的PyTorch代码实现_第1张图片简单来讲,就是说PyTorch中的.data()可以将变量(Variable)变为tensor,同时将requires_grad设置为Flase2,代码示例如下:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2021/11/26 12:03
# @Author  : William Baker
# @FileName: demo_torch_data.py
# @Software: PyCharm
# @Blog    : https://blog.csdn.net/weixin_43051346

# refer:  https://www.cnblogs.com/Guang-Jun/p/14771157.html
# w.grad.item()是等价于 w.grad.data的
import torch

# .data()将变量(Variable)变为tensor,将requires_grad设置为Flase,即不会建立计算图,更新时只修改其数值
a = torch.tensor([1.0], requires_grad=True)
print(a.requires_grad)   # True

b = a.data
print(b, b.requires_grad)        # 输出为: tensor([1.])  False
print(a.requires_grad)           # True

print(a.data.item(), type(a.data.item()))    # 1.0 
print(a.item(), type(a.item()))              # 1.0 

# b2 = a.grad.data
# print(b2, b2.requires_grad)      # AttributeError: 'NoneType' object has no attribute 'data'

PyTorch的 .item()

PyTorch中的.item()方法的功能是,以标准的Python数字的形式来返回这个张量的值。这个方法只能用于只包含一个元素的张量。

该方法返回的结果是普通Python数据类型,自然不能调用backward()方法来进行梯度的反向传播。代码示例如下:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2021/11/26 12:16
# @Author  : William Baker
# @FileName: demo_torch_item.py
# @Software: PyCharm
# @Blog    : https://blog.csdn.net/weixin_43051346

# refer:  https://www.cnblogs.com/Guang-Jun/p/14771157.html
import torch

a = torch.tensor([1.0], requires_grad=True)

c = a.item()
print(c, type(c))      # 输出为:1.0 

BP with SGD的PyTorch代码实现

我们知道,神经网络的正向传播Forward 是计算损失loss,创建新的计算图;而反向传播Backward 是求偏导进行梯度计算,即损失值loss对权重值w求梯度,计算完梯度后会将其存到变量(比如权重w)里面,存完之后计算图就会得到释放。

G r a d i e n t : ∂ l o s s ∂ w Gradient: \frac{\partial loss}{\partial w} Gradient:wloss

铺垫完毕,接下来上代码(还是以线性函数为例),具体的讲解尽在要多详细有多详细的注释中:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time    : 2021/11/24 17:37
# @Author  : William Baker
# @FileName: SGD_torch.py
# @Software: PyCharm
# @Blog    : https://blog.csdn.net/weixin_43051346

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import torch
import matplotlib.pyplot as plt


x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

#tensor中包含data(w)和grad(loss对w求导)
w = torch.tensor([1.0])     # w的初值为1.0
w.requires_grad = True      # 需要计算梯度
print(w.data)
# 学习率 learning rate
lr = 0.01

def forward(x):
    return x * w     # x和w都是tensor类型,数乘

def loss(x, y):
    # 每调用一次loss函数,就把计算图构建出来了
    y_pred = forward(x)
    return (y_pred - y) ** 2

# print("predict (before training)", 4, forward(4).item())
print('训练前的输入值x:{}, 训练前的预测值:{}\n'.format(4, forward(4).item()))
print("***************************开始训练***************************")    # 训练当达到最优值时,即损失函数梯度grad下降到梯度为0,W的值不再继续迭代

epoch_list = []
cost_list = []
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        # 下面两行,构建计算图的时候,直接使用张量进行运算(但是在权重更新的时候,要使用.data)
        l = loss(x, y)      # l是一个张量tensor,tensor主要是在建立计算图 forward, compute the loss,Forward 前馈是计算损失loss,创建新的计算图
        l.backward()        # backward,compute grad for Tensor whose requires_grad set to True 反向传播过程就会自动计算所需要的梯度
                            # Backward 反馈是计算梯度,计算完梯度后会将其存到变量(比如权重w)里面,存完之后计算图就会得到释放
                            # 每进行一次反向传播,把图释放,准备进行下一次的图
        """
        # print('\tgrad:', x, y, w.grad.item(), w.grad.item().type)    # 报错 AttributeError: 'float' object has no attribute 'type'
        print('\tgrad:', x, y, w.grad.item(), type(w.grad.item()))     # w.grad.item()   # grad:  2.0 4.0 -7.840000152587891 
        # print('\tgrad-:', x, y, w.grad, type(w.grad))                # w.grad          # grad-: 2.0 4.0 tensor([-7.8400]) 
        print('\tgrad-*:', x, y, w.grad, w.grad.type())                # w.grad          # grad-*: 2.0 4.0 tensor([-7.8400]) torch.FloatTensor
        print('\tgrad--:', x, y, w.grad.data, w.grad.data.type())      # w.grad.data     # grad--: 2.0 4.0 tensor([-7.8400]) torch.FloatTensor
        """
        print('\tgrad:', x, y, w.grad.item())                           # grad: 2.0 4.0 -7.840000152587891
        # print('\tgrad--:', x, y, w.data.item(), type(w.data.item()))    # grad--: 2.0 4.0 1.0199999809265137 

        # w -= lr * grad_val        # w = w - lr * gradient(w)   梯度下降的核心所在
        # print(w.data.requires_grad)    # False
        w.data = w.data - lr * w.grad.data       # 权重更新时,需要用到标量,注意grad也是一个tensor   # w.grad.item()是等价于 w.grad.data的,都是不建立计算图
        # print(w.data.requires_grad)    # False

        w.grad.data.zero_()     # after update, remember set the grad to zero     # 把权重里面的梯度数据清0,不然就变成了梯度累加

    epoch_list.append(epoch)
    cost_list.append(l)
    # print('progress:', epoch, l.item())
    print('Progress: Epoch {}, loss:{}'.format(epoch, l.item()))    # 取出loss使用l.item,不要直接使用l(l是tensor会构建计算图)
                                                                          # Progress: Epoch 99, loss:9.094947017729282e-13
    # print('Progress-: Epoch {}, loss:{}'.format(epoch, l.data.item()))  # Progress-: Epoch 99, loss:9.094947017729282e-13

print("***************************训练结束***************************\n")
# print("predict (after training)", 4, forward(4).item())
print('训练后的输入值x:{}, 训练后的预测值:{}'.format(4, forward(4).item()))

# 绘图
plt.plot(epoch_list, cost_list)
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()

输出结果如下:


tensor([1.])
训练前的输入值x:4, 训练前的预测值:4.0

***************************开始训练***************************
	grad: 1.0 2.0 -2.0
	grad: 2.0 4.0 -7.840000152587891
	grad: 3.0 6.0 -16.228801727294922
Progress: Epoch 0, loss:7.315943717956543
	grad: 1.0 2.0 -1.478623867034912
	grad: 2.0 4.0 -5.796205520629883
	grad: 3.0 6.0 -11.998146057128906
Progress: Epoch 1, loss:3.9987640380859375
	grad: 1.0 2.0 -1.0931644439697266
	grad: 2.0 4.0 -4.285204887390137
	grad: 3.0 6.0 -8.870372772216797
Progress: Epoch 2, loss:2.1856532096862793
	grad: 1.0 2.0 -0.8081896305084229
	grad: 2.0 4.0 -3.1681032180786133
	grad: 3.0 6.0 -6.557973861694336
Progress: Epoch 3, loss:1.1946394443511963
	grad: 1.0 2.0 -0.5975041389465332
	grad: 2.0 4.0 -2.3422164916992188
	grad: 3.0 6.0 -4.848389625549316
Progress: Epoch 4, loss:0.6529689431190491
	grad: 1.0 2.0 -0.4417421817779541
	grad: 2.0 4.0 -1.7316293716430664
	grad: 3.0 6.0 -3.58447265625
Progress: Epoch 5, loss:0.35690122842788696
	grad: 1.0 2.0 -0.3265852928161621
	grad: 2.0 4.0 -1.2802143096923828
	grad: 3.0 6.0 -2.650045394897461
Progress: Epoch 6, loss:0.195076122879982
	grad: 1.0 2.0 -0.24144840240478516
	grad: 2.0 4.0 -0.9464778900146484
	grad: 3.0 6.0 -1.9592113494873047
Progress: Epoch 7, loss:0.10662525147199631
	grad: 1.0 2.0 -0.17850565910339355
	grad: 2.0 4.0 -0.699742317199707
	grad: 3.0 6.0 -1.4484672546386719
Progress: Epoch 8, loss:0.0582793727517128
	grad: 1.0 2.0 -0.1319713592529297
	grad: 2.0 4.0 -0.5173273086547852
	grad: 3.0 6.0 -1.070866584777832
Progress: Epoch 9, loss:0.03185431286692619
	grad: 1.0 2.0 -0.09756779670715332
	grad: 2.0 4.0 -0.3824653625488281
	grad: 3.0 6.0 -0.7917022705078125
Progress: Epoch 10, loss:0.017410902306437492
	grad: 1.0 2.0 -0.07213282585144043
	grad: 2.0 4.0 -0.2827606201171875
	grad: 3.0 6.0 -0.5853137969970703
Progress: Epoch 11, loss:0.009516451507806778
	grad: 1.0 2.0 -0.053328514099121094
	grad: 2.0 4.0 -0.2090473175048828
	grad: 3.0 6.0 -0.43272972106933594
Progress: Epoch 12, loss:0.005201528314501047
	grad: 1.0 2.0 -0.039426326751708984
	grad: 2.0 4.0 -0.15455150604248047
	grad: 3.0 6.0 -0.3199195861816406
Progress: Epoch 13, loss:0.0028430151287466288
	grad: 1.0 2.0 -0.029148340225219727
	grad: 2.0 4.0 -0.11426162719726562
	grad: 3.0 6.0 -0.23652076721191406
Progress: Epoch 14, loss:0.0015539465239271522
	grad: 1.0 2.0 -0.021549701690673828
	grad: 2.0 4.0 -0.08447456359863281
	grad: 3.0 6.0 -0.17486286163330078
Progress: Epoch 15, loss:0.0008493617060594261
	grad: 1.0 2.0 -0.01593184471130371
	grad: 2.0 4.0 -0.062453269958496094
	grad: 3.0 6.0 -0.12927818298339844
Progress: Epoch 16, loss:0.00046424579340964556
	grad: 1.0 2.0 -0.011778593063354492
	grad: 2.0 4.0 -0.046172142028808594
	grad: 3.0 6.0 -0.09557533264160156
Progress: Epoch 17, loss:0.0002537401160225272
	grad: 1.0 2.0 -0.00870823860168457
	grad: 2.0 4.0 -0.03413581848144531
	grad: 3.0 6.0 -0.07066154479980469
Progress: Epoch 18, loss:0.00013869594840798527
	grad: 1.0 2.0 -0.006437778472900391
	grad: 2.0 4.0 -0.025236129760742188
	grad: 3.0 6.0 -0.052239418029785156
Progress: Epoch 19, loss:7.580435340059921e-05
	grad: 1.0 2.0 -0.004759550094604492
	grad: 2.0 4.0 -0.018657684326171875
	grad: 3.0 6.0 -0.038620948791503906
Progress: Epoch 20, loss:4.143271507928148e-05
	grad: 1.0 2.0 -0.003518819808959961
	grad: 2.0 4.0 -0.0137939453125
	grad: 3.0 6.0 -0.028553009033203125
Progress: Epoch 21, loss:2.264650902361609e-05
	grad: 1.0 2.0 -0.00260162353515625
	grad: 2.0 4.0 -0.010198593139648438
	grad: 3.0 6.0 -0.021108627319335938
Progress: Epoch 22, loss:1.2377059647405986e-05
	grad: 1.0 2.0 -0.0019233226776123047
	grad: 2.0 4.0 -0.0075397491455078125
	grad: 3.0 6.0 -0.0156097412109375
Progress: Epoch 23, loss:6.768445018678904e-06
	grad: 1.0 2.0 -0.0014221668243408203
	grad: 2.0 4.0 -0.0055751800537109375
	grad: 3.0 6.0 -0.011541366577148438
Progress: Epoch 24, loss:3.7000872907810844e-06
	grad: 1.0 2.0 -0.0010514259338378906
	grad: 2.0 4.0 -0.0041217803955078125
	grad: 3.0 6.0 -0.008531570434570312
Progress: Epoch 25, loss:2.021880391112063e-06
	grad: 1.0 2.0 -0.0007772445678710938
	grad: 2.0 4.0 -0.0030469894409179688
	grad: 3.0 6.0 -0.006305694580078125
Progress: Epoch 26, loss:1.1044940038118511e-06
	grad: 1.0 2.0 -0.0005745887756347656
	grad: 2.0 4.0 -0.0022525787353515625
	grad: 3.0 6.0 -0.0046634674072265625
Progress: Epoch 27, loss:6.041091182851233e-07
	grad: 1.0 2.0 -0.0004248619079589844
	grad: 2.0 4.0 -0.0016651153564453125
	grad: 3.0 6.0 -0.003444671630859375
Progress: Epoch 28, loss:3.296045179013163e-07
	grad: 1.0 2.0 -0.0003139972686767578
	grad: 2.0 4.0 -0.0012311935424804688
	grad: 3.0 6.0 -0.0025491714477539062
Progress: Epoch 29, loss:1.805076408345485e-07
	grad: 1.0 2.0 -0.00023221969604492188
	grad: 2.0 4.0 -0.0009107589721679688
	grad: 3.0 6.0 -0.0018854141235351562
Progress: Epoch 30, loss:9.874406714516226e-08
	grad: 1.0 2.0 -0.00017189979553222656
	grad: 2.0 4.0 -0.0006742477416992188
	grad: 3.0 6.0 -0.00139617919921875
Progress: Epoch 31, loss:5.4147676564753056e-08
	grad: 1.0 2.0 -0.0001270771026611328
	grad: 2.0 4.0 -0.0004978179931640625
	grad: 3.0 6.0 -0.00102996826171875
Progress: Epoch 32, loss:2.9467628337442875e-08
	grad: 1.0 2.0 -9.393692016601562e-05
	grad: 2.0 4.0 -0.0003681182861328125
	grad: 3.0 6.0 -0.0007610321044921875
Progress: Epoch 33, loss:1.6088051779661328e-08
	grad: 1.0 2.0 -6.937980651855469e-05
	grad: 2.0 4.0 -0.00027179718017578125
	grad: 3.0 6.0 -0.000560760498046875
Progress: Epoch 34, loss:8.734787115827203e-09
	grad: 1.0 2.0 -5.125999450683594e-05
	grad: 2.0 4.0 -0.00020122528076171875
	grad: 3.0 6.0 -0.0004177093505859375
Progress: Epoch 35, loss:4.8466972657479346e-09
	grad: 1.0 2.0 -3.790855407714844e-05
	grad: 2.0 4.0 -0.000148773193359375
	grad: 3.0 6.0 -0.000308990478515625
Progress: Epoch 36, loss:2.6520865503698587e-09
	grad: 1.0 2.0 -2.8133392333984375e-05
	grad: 2.0 4.0 -0.000110626220703125
	grad: 3.0 6.0 -0.0002288818359375
Progress: Epoch 37, loss:1.4551915228366852e-09
	grad: 1.0 2.0 -2.09808349609375e-05
	grad: 2.0 4.0 -8.20159912109375e-05
	grad: 3.0 6.0 -0.00016880035400390625
Progress: Epoch 38, loss:7.914877642178908e-10
	grad: 1.0 2.0 -1.5497207641601562e-05
	grad: 2.0 4.0 -6.103515625e-05
	grad: 3.0 6.0 -0.000125885009765625
Progress: Epoch 39, loss:4.4019543565809727e-10
	grad: 1.0 2.0 -1.1444091796875e-05
	grad: 2.0 4.0 -4.482269287109375e-05
	grad: 3.0 6.0 -9.1552734375e-05
Progress: Epoch 40, loss:2.3283064365386963e-10
	grad: 1.0 2.0 -8.344650268554688e-06
	grad: 2.0 4.0 -3.24249267578125e-05
	grad: 3.0 6.0 -6.580352783203125e-05
Progress: Epoch 41, loss:1.2028067430946976e-10
	grad: 1.0 2.0 -5.9604644775390625e-06
	grad: 2.0 4.0 -2.288818359375e-05
	grad: 3.0 6.0 -4.57763671875e-05
Progress: Epoch 42, loss:5.820766091346741e-11
	grad: 1.0 2.0 -4.291534423828125e-06
	grad: 2.0 4.0 -1.71661376953125e-05
	grad: 3.0 6.0 -3.719329833984375e-05
Progress: Epoch 43, loss:3.842615114990622e-11
	grad: 1.0 2.0 -3.337860107421875e-06
	grad: 2.0 4.0 -1.33514404296875e-05
	grad: 3.0 6.0 -2.86102294921875e-05
Progress: Epoch 44, loss:2.2737367544323206e-11
	grad: 1.0 2.0 -2.6226043701171875e-06
	grad: 2.0 4.0 -1.049041748046875e-05
	grad: 3.0 6.0 -2.288818359375e-05
Progress: Epoch 45, loss:1.4551915228366852e-11
	grad: 1.0 2.0 -1.9073486328125e-06
	grad: 2.0 4.0 -7.62939453125e-06
	grad: 3.0 6.0 -1.430511474609375e-05
Progress: Epoch 46, loss:5.6843418860808015e-12
	grad: 1.0 2.0 -1.430511474609375e-06
	grad: 2.0 4.0 -5.7220458984375e-06
	grad: 3.0 6.0 -1.1444091796875e-05
Progress: Epoch 47, loss:3.637978807091713e-12
	grad: 1.0 2.0 -1.1920928955078125e-06
	grad: 2.0 4.0 -4.76837158203125e-06
	grad: 3.0 6.0 -1.1444091796875e-05
Progress: Epoch 48, loss:3.637978807091713e-12
	grad: 1.0 2.0 -9.5367431640625e-07
	grad: 2.0 4.0 -3.814697265625e-06
	grad: 3.0 6.0 -8.58306884765625e-06
Progress: Epoch 49, loss:2.0463630789890885e-12
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 50, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 51, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 52, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 53, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 54, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 55, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 56, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 57, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 58, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 59, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 60, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 61, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 62, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 63, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 64, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 65, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 66, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 67, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 68, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 69, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 70, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 71, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 72, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 73, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 74, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 75, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 76, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 77, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 78, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 79, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 80, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 81, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 82, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 83, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 84, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 85, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 86, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 87, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 88, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 89, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 90, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 91, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 92, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 93, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 94, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 95, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 96, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 97, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 98, loss:9.094947017729282e-13
	grad: 1.0 2.0 -7.152557373046875e-07
	grad: 2.0 4.0 -2.86102294921875e-06
	grad: 3.0 6.0 -5.7220458984375e-06
Progress: Epoch 99, loss:9.094947017729282e-13
***************************训练结束***************************

训练后的输入值x:4, 训练后的预测值:7.999998569488525

Process finished with exit code 0

使用随机梯度下降SGD的BP反向传播算法的PyTorch代码实现_第2张图片迭代过程中,Loss值随着Epoch的变化如上图所示,从该图以及程序的输出结果以及可以得出,随着训练逐渐达到最优值,损失函数减低的同时,损失函数对权重的梯度grad也逐渐下降为0,这正是梯度下降到了最低点,即函数梯度值为0的点。如下图二维和三维的图所示,最后来到了红色的点,该点即为要求的最优点3(下面两个图片来自前辈的文章,因为画得已经很好了,懒得再画了,已做引用处理,侵删)
使用随机梯度下降SGD的BP反向传播算法的PyTorch代码实现_第3张图片使用随机梯度下降SGD的BP反向传播算法的PyTorch代码实现_第4张图片


写到这里,差不多本文也就要结束了。如果我的这篇文章帮助到了你,那我也会感到很高兴,一个人能走多远,在于与谁同行


参考文章


  1. 《PyTorch深度学习实践》完结合集 - 04.反向传播
    ↩︎

  2. pytorch中.data()与.item()
    ↩︎

  3. 梯度下降算法原理以及其实现
    ↩︎

你可能感兴趣的:(PyTorch,深度学习,pytorch,人工智能,python)