【PyTorch深度学习实践】学习笔记 第四节 反向传播

课程链接PyTorch深度学习实践

  • 开始正题前,先做一个知识的补充,关于python的张量——tensor。(不要小看这些细节,有模糊的地方一定要及时搞清楚┗|`O′|┛ 嗷,以防在后面越来越多的使用中出现迷糊。)

开始正题!

案例:根据学习时长 推断成绩
如下图,前三组数据是我们一直的学习时间x与分数y的关系,我们需要推断出x=4的时候y是多少。
【PyTorch深度学习实践】学习笔记 第四节 反向传播_第1张图片这里为了方便效果演示,数据凑的很好,是一个简单的线性映射关系。但在实际中,数据会有很多偏差,我们不容易知道其映射关系,这就是为什么要使用深度学习。

import torch
import matplotlib.pyplot as plt

# prepare the training set
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

w = torch.Tensor([1.0])  # w的初值为1.0 使用的是Tensor创建的
w.requires_grad = True  # w是个单精度浮点型张量,由grad和data组成,需要计算grad梯度则置为true

def forward(x):
    return x * w  # w是一个Tensor!

def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

print("predict (before training)", 4, forward(4).item())
epoch_list = []
loss_list = []
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        l = loss(x, y)  # l是一个张量,每调用一次loss函数则会构建一次计算图(通过forward)
        l.backward()  # backward,compute grad for Tensor whose requires_grad set to True
        print('\tgrad:', x, y, w.grad.item())
        w.data = w.data - 0.01 * w.grad.data  # 权重更新时,需要用到标量,注意grad也是一个tensor
        w.grad.data.zero_()  # after update, remember set the grad to zero
    epoch_list.append(epoch)
    loss_list.append(l.item())
    print('progress:', epoch, l.item())  # 取出loss使用l.item,不要直接使用l(l是tensor会构建计算图)
print("predict (after training)", 4, forward(4).item())
# 绘制函数
plt.plot(epoch_list, loss_list)
plt.ylabel('Loss')
plt.xlabel('w')
plt.show()

上述代码有许多细节。
1、w是张量tensor类型,tensor中包含data和grad,data和grad也是tensor,即其grad梯度里也有data。(好奇的可以自己debug看一下数据类型)。grad初始为None,调用l.backward()方法后w.grad为tensor,故更新w.data时使用的是w.grad.data!即一定要注意这里的使用

 w.data = w.data - 0.01 * w.grad.data  # 权重更新时,需要用到标量,注意grad也是一个tensor
        w.grad.data.zero_()  # after update, remember set the grad to zero

2、因为w是tensor,forward函数的返回值也是tensor,loss函数的返回值也是tensor。
3、如果w需要计算梯度,那构建的计算图中,跟w相关的tensor都默认需要计算梯度。
3、本算法中反向传播主要体现在,l.backward()! l 是一个张量,每调用一次loss函数则会构建一次计算图(通过forward),然后再通过l.backward()自动将 l 的计算图中require_grad =true 的tensor 计算出偏导存入其grad中。这里w定义为require_grad=true的tensor了。对比第三节 梯度下降课程,该节课省去了之前的def grad函数了,不用再手动写函数计算梯度了。调用backward()直接计算~ 调用该方法后w.grad由None更新为tensor类型,且w.grad.data的值用于后续w.data的更新。、

  • 前面之所以要讲到Tensor是因为涉及到计算图和auto grad的理解。

Autograd自动求导

autograd包提供Tensor所有操作的自动求导方法。
torch.Tensor是这个包里面最重要的类。如果设置了requires_grad为True,那么它开始追踪所有在它上面的操作。当你完成了前向传播的运算——构建完计算图,可以调用backward(),会自动计算所有的梯度。然后这个tensor的梯度会被自动累积到grad属性上。
pytorch框架干的最厉害的事就是帮我们把反向传播全部计算好了。自动求导机制就是在最后一层的输出进行backward()计算反向传播,其他的中间节点,叶节点的tensor的grad会自动求导!
【PyTorch深度学习实践】学习笔记 第四节 反向传播_第2张图片

作业:

画出二次模型y=w1x²+w2x+b,损失函数loss=(ŷ-y)²的计算图,并且手动推导反向传播的过程,最后用pytorch的代码实现。
答:
构建和推导的过程
【PyTorch深度学习实践】学习笔记 第四节 反向传播_第3张图片

永远四步:

  1. 准备数据集
  2. 设计模块
  3. 构建损失函数和优化函数
  4. 循环训练

定义三个tensor,require_grad=true
定义函数 forward 、loss
训练循环,在每个样本中都调用loss计算图,更新权值

import numpy as np
import matplotlib.pyplot as plt
import torch

x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]

w1 = torch.Tensor([1.0])#初始权值
w1.requires_grad = True#计算梯度,默认是不计算的
w2 = torch.Tensor([1.0])
w2.requires_grad = True
b = torch.Tensor([1.0])
b.requires_grad = True

def forward(x):
    return w1 * x**2 + w2 * x + b

def loss(x,y):#构建计算图
    y_pred = forward(x)
    return (y_pred-y) **2

print('Predict (befortraining)',4,forward(4))

for epoch in range(100):
    l = loss(1, 2)#为了在for循环之前定义l,以便之后的输出,无实际意义
    for x,y in zip(x_data,y_data):
        l = loss(x, y)
        l.backward()
        print('\tgrad:',x,y,w1.grad.item(),w2.grad.item(),b.grad.item()) #调试用的
        # 更新权值
        w1.data = w1.data - 0.01*w1.grad.data #注意这里的grad是一个tensor,所以要取他的data
        w2.data = w2.data - 0.01 * w2.grad.data
        b.data = b.data - 0.01 * b.grad.data
        w1.grad.data.zero_() #释放之前计算的梯度
        w2.grad.data.zero_()
        b.grad.data.zero_()
    print('Epoch:',epoch,l.item())

print('Predict(after training)',4,forward(4).item())
print('final w1, w2, b:', w1.data, w2.data, b.data)

结果

Predict (befortraining) 4 tensor([21.], grad_fn=)
Predict(after training) 4 8.544171333312988

by 小李

如果你坚持到这里了,请一定不要停,山顶的景色更迷人!好戏还在后面呢。加油!

你可能感兴趣的:(深度学习pytorch,pytorch,深度学习,python)