【深度学习】第四章:反向传播-梯度计算-更新参数

四、训练模型:反向传播-梯度计算-更新参数

1、计算图(Computational Graph)
为什么深度网络模型不建议手写呢,因为底层有太多的东西,手写就写到地老天荒了,其中计算图就是一个难点。
pytorch框架中还封装了计算图,计算图的通俗理解就是:只要你有计算,并且设置了requires_grad=True,pytorch就会自动同时帮你生成这个计算过程的计算图。
我们先用个例子来直观的理解一下什么是计算图,也就是计算图的原理和作用:

那在计算p的同时,pytorch就帮我们生成了下面的计算图:【深度学习】第四章:反向传播-梯度计算-更新参数_第1张图片

说明:
(1)上面的计算图是用来描述运算过程的有向无环图,就是我们现在讲的计算图。其中图中的abcxyzmnp都是叫节点Node,是用来存储数据的,比如向量、矩阵、张量等;图中的箭头都叫边Edge,是用来表示运算关系的,比如加减乘除、矩阵乘、三角函数、卷积计算等等各种运算关系。
(2)节点又分叶子节点和非叶子节点。上图中abc三个节点是用方框框住的,它们三个是叶子节点,xyzmnp是用椭圆框住的,它们6个是非叶子节点。叶子节点和非叶子节点可以用.is_leaf属性来查看。
(3)上图数据从abc一路计算到p的过程也就是一个正向传播的过程。所以一次正向传播,就同时生成这么一张数据计算过程的计算图。
(4)要这个计算图干什么?自然是有用了,不然费这个劲儿干嘛,当然是用来反向传播求梯度了。
(5)上图的例子我们还可以写出p的显式表达式:p = sin((0.5*((a+b)*c)^2-1)/2-0.1),可见,p就是关于abc的函数,那我们就可以求p关于abc的偏导数。
当p和abc之间有明确函数关系的时候,我们可以通过高数中的求导公式求偏导。
但是当p和abc之间的函数关系非常非常复杂,以至于无法显性写出来的时候,用求导公式求导是不是就行不通了呀,此时,感谢这个计算图,它让我们的求导变得异常轻松!
记不记得,求导除了用求导公式,是不是还可以用链式求导法。我想求p关于a的偏导数,用链式求导是不是就是:
偏p/偏n * 偏n/偏m * 偏m/偏z * 偏z/偏y * 偏y/偏x * 偏x/偏a
这是不是大大简化了求导的难度,里面每个乘项都非常非常的简单,而且几乎都是不用计算的,因为如果正向传播的时候的中间结果都保存下来,那偏p/偏n、偏n/偏m。。。等就变得异常简单,最后连乘就是了。
(6)可见,即使你再复杂的函数关系,pytorch都是把你的计算步骤给细分到一个个它底层定义的运算颗粒,这样一个复杂的运算就被分解成一堆有序的、有向的、极简的运算颗粒。
(7)可见,什么是反向传播?反向传播就是根据正向传播时保留的中间结果求上面连乘的项!也就是链式求导的过程。知识点都打通了吧。
(8)那我们求导干嘛?假如p就是损失,abc是参数,是不是要反向传播求导,导数就是参数abc点对应的p下降最快的方向!那就顺着这个方向更新abc,是不是p就不断减小了,当损失p减小到我们可以接受的小的程度时,此时的abc是不是就是模型的最佳参数了,就是模型训练完成了。简单的说,求导就是求梯度,有了梯度就可以更新参数,更新了的新参数就是损失函数减小的一组参数,如此往复,就是模型学习样本数据,就是我们训练了模型。当损失函数小到我们可以接受的程度,此时的模型就是最优的模型,此时的模型就可以很好的帮我们人类预测样本了。

上面说的都是计算图的原理和作用,任何方法的实现都得借助代码实现,所以下面演示一下pytorch中和求导相关的概念、方法和属性。【深度学习】第四章:反向传播-梯度计算-更新参数_第2张图片

说明:A处是我们生成叶子节点的代码。在生成叶子节点的时候一定要给张量添加requires_grad=True这个属性,这样的tensor对象才在参与运算的时候生成计算图。计算时生成计算图肯定是消耗时间也消耗算力的,所以只有用户要求计算梯度时,才同时生成计算图,这也是为了提升效率。

2、反向传播
反向传播,从数学角度说,就是根据上面的计算图求叶子节点的导数,而且是用链式求导法求的导数的。从深度学习角度说,就是求模型参数的梯度值。从目的来说,就是寻找损失函数逐渐减小的模型参数。从模型角度来说,就是模型学习的过程。从训练的角度来说,就是模型训练的过程。

那反向传播既然是一个求导的过程,从代码角度来看,那整个计算过程中的所有对象(就是所有的节点)都必须是.requires_grad=True,这样整个计算过程才能一步步向回追溯,直到追溯到叶子节点的导数。如果中间任何一个节点的requires_grad=False了,那就不能backward了。如上图所示我们的所有节点abcxyzmnp都是requires_grad=True,所以我们就可以顺利计算出abc的导数。

那此处就出现了另外一个问题:我如果想求中间节点对叶子节点的导数呢?比如y对a的导数。或者我想求节点p对中间节点的导数呢?比如求p对z或者p对n这些中间节点的导数。
pytorch为了节省内存空间,中间结果的梯度都给丢弃了,如下图A,我们对p进行bachward后,只能查看叶子节点abc的导数,而中间节点xyzmn的导数都无法查看。无法查看不是没有计算,而且计算了但是没有保存计算结果,从内存中都释放了。如果我们想查看各个环节的导数,只需要搭配使用.backward和.retain_grad就可以了。比如如下图B,你只要在添加y.retain_grad()这行代码,你就可以查看,偏p/偏y的导数了。再比如下图C,你只要从z开始backward,你就可以查看偏z/偏叶子节点的导数了。【深度学习】第四章:反向传播-梯度计算-更新参数_第3张图片

如果你不想计算梯度了,毕竟计算梯度还得生成计算图,如何这个梯度对我没啥用的话,就是浪费资源,此时你可以用with torch.no_grad()语句或者用.detach()方法。【深度学习】第四章:反向传播-梯度计算-更新参数_第4张图片

如果你想求一个函数的微分,这个函数的显式公式你也可以写出来,那是不是还得手动把这个函数切分成一个个计算颗粒,再backward求微分啊?不用,pytorch提供了torch.autograd.grad()函数,可以直接求微分:

用pytorch求导基本的知识点就是这么多就足够用了。

3、梯度下降-更新参数
上面讲了如何求导。我们要清楚,这个求导是在求损失函数对模型参数的导数这个环节的。所以此处承接上篇文章的损失函数部分:【深度学习】第四章:反向传播-梯度计算-更新参数_第5张图片

这是我们前面讲到的损失函数。当时我们就向前传播了鸢尾花数据集的前2条样本。这2条样本从我们自建的架构的输入层喂入,经过2个隐藏层,隐藏层后面都跟一个relu激活层,然后到输出层,再到softmax层,输出就是上图的B预测结果。我们这两条样本的真实标签都是1,而且我们是一个三分类任务,所以把预测结果B和真实标签带入交叉熵公式,就计算出本次小批次的总损失:loss = -(logp1+logp2)。这就是我们的损失函数,而这个损失函数的计算过程也被计算图全程记录了。所以我们可以loss.backward(),就可以求出loss对模型所有参数w和b的偏导数grad(w)、grad(b)。

此时,我们搭建模型、让模型帮助我们进行预测的问题,就转化成:寻找loss函数的最小值。
为什么?因为只有当loss最小,就说明模型把这些样本都预测对了呀。比如上图的loss如果约等于0,那就相当于第一条样本的第一个类概率值趋近于1,并且第二条样本的第一个类概率值也趋近于1。反过来也可以印证:如果两条样本的第一个数都接近1,说明这两条样本的预测结果都是类别0,也同时说明此时loss趋近0。

所以,此时的预测问题就转化成了如何找到"loss函数最小值对应的那组参数"。这个问题在数学上就叫优化问题。

而这个优化问题最常用的解法就是:梯度下降优化算法,即从一个随机点出发,一步步逼近最优解。 如何从一个随机点出发?将我们的模型的初始参数随机化初始设置即可呀,就这么简单。
那又如何一步步逼近?就是通过:w = w - lr*grad。其中lr表示学习率learning rate, grad就是我们前面求的导数,也叫梯度。为什么?
因为loss是一个复杂的、不能显性写出表达式的、关于参数w和b的一个函数。但是我们通过计算图,已经把从叶子节点(w、b、样本的特征)到loss的整个计算过程都记录下来了。
我们计算loss关于w、b的偏导,从几何角度看,就是在loss函数的(w,b)点,找到了loss的切面,而沿着这个切面的反方向就是loss下降最快的方向,所以当(w,b)沿切面反方向走到(w',b')时,loss就变小了一点。其中w' = w - lr*grad(w), 同理b' = b - lr*grad(b)。
那我们重复上面的步骤很多次,直到走到loss接近0时,此时的(w,b)就是使loss最小的一组参数。

上面的过程就是梯度下降算法的原理和数学过程。
下面我用一个小例子展示一下梯度下降算法:【深度学习】第四章:反向传播-梯度计算-更新参数_第6张图片

上图就是一个梯度下降--更新参数的过程。也是一个模型训练的过程。
当然我们DNN构建的模型,loss不像上图的y那么简单,w也不像上图的x那么少。但是原理都是一样的。如果用DNN展示就不太好展示,用一元二次函数比较好展示,所以就举了上面的例子。
至此,我们从整理数据-搭建模型-正向传播-求损失函数-反向传播求梯度-参数迭代,就完成了一个模型的训练全过程。再往后就是模型调优以及保存模型。

4、用鸢尾花数据集的展示完整的训练过程

待续。。。。

你可能感兴趣的:(深度学习,人工智能)