[PyTorch][chapter 8][李宏毅深度学习][Back propagation]

前言:

              反向传播算法(英:Backpropagation algorithm,简称:BP算法)是一种监督学习算法,常被用来训练多层感知机。 它用于计算梯度计算中,降低误差。

      

目录:

  1.     链式法则
  2.     模型简介(Model)
  3.     损失函数,梯度
  4.     手写例子
  5.     min-batch

一  链式法则

      链式法则是反向传播算法里面的核心。

     case1: y=g(x),z=h(y), x,y,z 都是scalar

                       

                     \frac{dz }{dx}=\frac{dz }{dy}\frac{dy }{dx}        

      case2:  x=g(s),y=h(s),z=k(x,y),s,x,y,z 都是scalar

                   [PyTorch][chapter 8][李宏毅深度学习][Back propagation]_第1张图片

                       \frac{dz}{ds}=\frac{dz}{dy}\frac{dy}{ds}+\frac{dz}{dx}\frac{dx}{ds}

      case3:   x,y,z 都是向量vector

                   x\rightarrow y\rightarrow z

                    \frac{dz }{dx}=\frac{dz }{dy}\frac{dy }{dx}


二  模型(Model)

以常用的网络模型DNN 为例:

[PyTorch][chapter 8][李宏毅深度学习][Back propagation]_第2张图片

[PyTorch][chapter 8][李宏毅深度学习][Back propagation]_第3张图片

 激活函数为 \sigma

 总的层数为 L


三    损失函数,梯度

       3.1 损失函数

           J(w,b)=||a^{L}-y||_2^{2}

       3.2 梯度更新

               梯度计算分为两步:

   Forward pass, Backward pass

         a Forward pass

               假设 \delta^{l}=\frac{\partial J}{\partial z^l}:

            利用微分和迹的关系很容易得到

         [PyTorch][chapter 8][李宏毅深度学习][Back propagation]_第4张图片

          b  Backward pass  

               假设为最后一层L

                 \delta^{L}=(\frac{\partial a^L}{\partial z^L})^T\frac{\partial J}{\partial a^L}

                       =diag(\sigma^{'}(z^{L}))(a^{L}-\hat{y})

                      =(a^{L}-\hat{y})\odot \sigma{'}(z^{L})

            我们用数学归纳法,第L层的\delta^{L}已经求出, 假设第l+1层的\delta^{l+1}已经求出来了,那么我们如何求出第l层的\delta^{l}呢?

                \delta^{l}=\frac{\partial J}{\partial z^{l}}

                    =(\frac{\partial z^{l+1}}{\partial z^{l}})^T\frac{\partial J}{\partial z^{l+1}}

                    =(\frac{\partial z^{l+1}}{\partial a^l}\frac{\partial a^{l}}{\partial z^l})^T \delta^{l+1}

                    =(diag(\sigma^{'}(z^l)(w^{l+1})^T)\delta^{l+1}

                    =(w^{l+1})^T\delta^{t+1}\odot \sigma^{'}(z^l)


四   简单DNN 网络例子

 4.1 说明:

          这里面随机生成5张图形,分别对应手写数字1,2,3,4,5。

简单的了解一下如何快速搭建一个DNN Model, 梯度如何计算,更新的.

 

# -*- coding: utf-8 -*-
"""
Created on Fri Dec 15 17:21:35 2023

@author: chengxf2
"""

import torch 
from torch import nn
from torch import optim


class DNN(nn.Module):
    
    '''
    它是一个序列容器,是nn.Module的子类。 
    `nn.Sequential` 中的层是有顺序的,而且严格按照其顺序执行
    相邻两个层连接必须保证前一个层的输出与后一个层的输入相匹配。
    '''
    def __init__(self):
        
        super(DNN, self).__init__()
        
        self.net = nn.Sequential(
            nn.Linear(in_features=28*28, out_features=500),
            nn.Sigmoid(),
            nn.Linear(in_features=500, out_features=10),
            nn.Sigmoid()
            )

    def forward(self, input):
        
        output = self.net(input)
        
        return output


def train():
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    model = DNN()
    criteon = torch.nn.CrossEntropyLoss(reduction='mean')
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    batch_size= 5
    data = torch.rand((batch_size,28*28))
    epochs = 2
    target = torch.tensor([0,1,2,3,4])
    target = target.to(device)
    
    for epoch in range(epochs):
        
        yHat = model(data)
        loss = criteon(yHat, target)
        loss.backward()
        print("\n loss ",loss)
        
        optimizer.step()
        

if __name__ == "__main__":
    train()
    
    
    

 [PyTorch][chapter 8][李宏毅深度学习][Back propagation]_第5张图片


五  min-batch

  在深度学习训练中,数据集我们通常采用min-batch 方案

[PyTorch][chapter 8][李宏毅深度学习][Back propagation]_第6张图片

    我们采用随机梯度方法,是为了加快运算速度。

但是GPU 可以并行运算,所以可以采用min-batch 方法进行梯度计算。

   使用min-batch 有个限制:

    1: 硬件限制 batch 不能超过硬件大小

    2:    batch 不能太大,否则容易陷入到局部极小值点,采用小的batch 可以有一定的随机性

每次出发点都不一样,一定概率跳过局部极小值点

[PyTorch][chapter 8][李宏毅深度学习][Back propagation]_第7张图片

参考:

7: Backpropagation_哔哩哔哩_bilibili

https://www.cnblogs.com/pinard/p/6422831.html

CSDN

8-1: “Hello world” of deep learning_哔哩哔哩_bilibili

你可能感兴趣的:(深度学习,pytorch,人工智能)