多层神经网络,从零开始——(七)、损失函数

常用的损失函数对于回归问题有均方误差,对于分类问题一般使用的是交叉熵。通常为了提高泛化性能,还要使用正则化,此时损失函数的参数还包含了权值信息。这里正则化不在损失函数里考虑(当然可以将网络结构输入损失函数从而直接在损失函数里面处理正则化),如果需要使用正则化,直接修改 NNStructure 模块里计算误差对权值的导数函数即可。

一、损失函数基类

module mod_BaseLossFunction
implicit none
    
!-------------------
! 抽象类:损失函数 |
!-------------------
type, abstract, public :: BaseLossFunction

!||||||||||||    
contains   !|
!||||||||||||

    !* 损失函数
    procedure(abs_loss), deferred, public :: loss 
 
    !* 损失函数对最后一层激活函数自变量的导数
    !* 定义见PDF文档
    procedure(abs_d_loss), deferred, public :: d_loss
    
    procedure(abs_print_msg), deferred, public :: print_msg  

end type BaseLossFunction
!===================
    

!-------------------
! 抽象类:函数接口 |
!-------------------    
abstract interface   

    !* 损失函数
    !* 该函数暂时未用到
    subroutine abs_loss( this, t, y, ans )
    use mod_Precision
    import :: BaseLossFunction
    implicit none
        class(BaseLossFunction), intent(inout) :: this
        !* t 是目标输出向量,y 是网络预测向量
        real(PRECISION), dimension(:), intent(in) :: t
        real(PRECISION), dimension(:), intent(in) :: y
        real(PRECISION), intent(inout) :: ans

    end subroutine
    !====
    
    !* 损失函数对最后一层激活函数自变量的导数
    !* 返回对网络预测向量的导数
    subroutine abs_d_loss( this, t, r, z, act_fun, dloss )
    use mod_Precision
    use mod_BaseActivationFunction
    import :: BaseLossFunction
    implicit none
        class(BaseLossFunction), intent(inout) :: this
        !* t 是目标输出向量,
        !* r 是最后一层激活函数的自变量,
        !* z 是网络预测向量
        !* act_fun 是最后一层的激活函数,
        !* dloss 是损失函数对 r 的导数
        real(PRECISION), dimension(:), intent(in) :: t
        real(PRECISION), dimension(:), intent(in) :: r
        real(PRECISION), dimension(:), intent(in) :: z
        class(BaseActivationFunction), pointer, intent(in) :: act_fun
        real(PRECISION), dimension(:), intent(inout) :: dloss

    end subroutine
    !==== 
    
    
    !* 输出信息
    subroutine abs_print_msg( this )
    import :: BaseLossFunction
    implicit none
        class(BaseLossFunction), intent(inout) :: this

    end subroutine
    !====

end interface
!===================
    
end module

二、均方误差

module mod_MeanSquareError
use mod_Precision
use mod_BaseLossFunction
implicit none    

!-------------------
! 工作类:损失函数 |
!-------------------
type, extends(BaseLossFunction), public :: MeanSquareError
    !* 继承自BaseLossFunction并实现其接口

!||||||||||||    
contains   !|
!||||||||||||

    procedure, public :: loss  => m_fun_MeanSquareError
    procedure, public :: d_loss => m_df_MeanSquareError
    
    procedure, public :: print_msg => m_print_msg

end type MeanSquareError
!===================

    !-------------------------
    private :: m_fun_MeanSquareError
    private :: m_df_MeanSquareError
    private :: m_print_msg
    !-------------------------
    
!||||||||||||    
contains   !|
!||||||||||||

    !* MeanSquareError函数
    subroutine m_fun_MeanSquareError( this, t, y, ans )
    implicit none
        class(MeanSquareError), intent(inout) :: this
        !* t 是目标输出向量,对于分类问题,
        !* 它是one-hot编码的向量
        !* y 是网络预测向量
        real(PRECISION), dimension(:), intent(in) :: t
        real(PRECISION), dimension(:), intent(in) :: y
        real(PRECISION), intent(inout) :: ans
    
        ans = 0.5 * DOT_PRODUCT(y - t, y - t)
    
        return
    end subroutine
    !====
    
    !* 损失函数对最后一层激活函数自变量的导数
    !* 返回对网络预测向量的导数
    subroutine m_df_MeanSquareError( this, t, r, z, act_fun, dloss )
    use mod_BaseActivationFunction
    implicit none
        class(MeanSquareError), intent(inout) :: this
        !* t 是目标输出向量,
        !* r 是最后一层激活函数的自变量,
        !* z 是网络预测向量
        !* act_fun 是最后一层的激活函数,
        !* dloss 是损失函数对 r 的导数
        real(PRECISION), dimension(:), intent(in) :: t
        real(PRECISION), dimension(:), intent(in) :: r
        real(PRECISION), dimension(:), intent(in) :: z
        class(BaseActivationFunction), pointer, intent(in) :: act_fun
        real(PRECISION), dimension(:), intent(inout) :: dloss

        real(PRECISION), dimension(:), allocatable :: df_to_dr
        
        allocate( df_to_dr, SOURCE=r )
        
        !* df_to_dr 为 f'(r)
        call act_fun % df_vect( r, df_to_dr )
        
        dloss = (z - t) * df_to_dr
        
        deallocate( df_to_dr )
        
        return
    end subroutine
    !==== 
    
    !* 输出信息
    subroutine m_print_msg( this )
    implicit none
        class(MeanSquareError), intent(inout) :: this

        write(*, *) "Mean Square Error Function."
        
        return
    end subroutine
    !====

end module

三、交叉熵

module mod_CrossEntropy
use mod_Precision
use mod_BaseLossFunction
implicit none    

!-------------------
! 工作类:损失函数 |
!-------------------
type, extends(BaseLossFunction), public :: CrossEntropyWithSoftmax
    !* 继承自BaseLossFunction并实现其接口

!||||||||||||    
contains   !|
!||||||||||||

    procedure, public :: loss  => m_fun_CrossEntropy
    procedure, public :: d_loss => m_df_CrossEntropy
    
    procedure, public :: print_msg => m_print_msg

end type CrossEntropyWithSoftmax
!===================

    !-------------------------
    private :: m_fun_CrossEntropy
    private :: m_df_CrossEntropy
    private :: m_print_msg
    !-------------------------
    
!||||||||||||    
contains   !|
!||||||||||||

    !* CrossEntropy函数
    subroutine m_fun_CrossEntropy( this, t, y, ans )
    implicit none
        class(CrossEntropyWithSoftmax), intent(inout) :: this
        !* t 是目标输出向量,对于分类问题,
        !* 它是one-hot编码的向量
        !* y 是网络预测向量
        real(PRECISION), dimension(:), intent(in) :: t
        real(PRECISION), dimension(:), intent(in) :: y
        real(PRECISION), intent(inout) :: ans
    
        ans = -DOT_PRODUCT(t, LOG(y))
    
        return
    end subroutine
    !====
    
    !* CrossEntropy损失函数对最后一层激活函数自变量的导数
    !* 返回对网络预测向量的导数
    subroutine m_df_CrossEntropy( this, t, r, z, act_fun, dloss )
    use mod_BaseActivationFunction
    implicit none
        class(CrossEntropyWithSoftmax), intent(inout) :: this
        !* t 是目标输出向量,
        !* r 是最后一层激活函数的自变量,
        !* z 是网络预测向量
        !* act_fun 是最后一层的激活函数,
        !* dloss 是损失函数对 r 的导数
        real(PRECISION), dimension(:), intent(in) :: t
        real(PRECISION), dimension(:), intent(in) :: r
        real(PRECISION), dimension(:), intent(in) :: z
        class(BaseActivationFunction), pointer, intent(in) :: act_fun
        real(PRECISION), dimension(:), intent(inout) :: dloss
        
        dloss = (z - t) 
        
        return
    end subroutine
    !====
    

    !* 输出信息
    subroutine m_print_msg( this )
    implicit none
        class(CrossEntropyWithSoftmax), intent(inout) :: this

        write(*, *) "Cross Entropy Function."
        
        return
    end subroutine
    !====
    
    
end module

附录

多层神经网络,从零开始——(一)、Fortran读取MNIST数据集
多层神经网络,从零开始——(二)、Fortran随机生成“双月”分类问题数据
多层神经网络,从零开始——(三)、BP神经网络公式的详细推导
多层神经网络,从零开始——(四)、多层BP神经网络的矩阵形式
多层神经网络,从零开始——(五)、定义数据结构
多层神经网络,从零开始——(六)、激活函数
多层神经网络,从零开始——(七)、损失函数
多层神经网络,从零开始——(八)、分类问题中为什么使用交叉熵作为损失函数
多层神经网络,从零开始——(九)、优化函数
多层神经网络,从零开始——(十)、参数初始化
多层神经网络,从零开始——(十一)、实现训练类
多层神经网络,从零开始——(十二)、实现算例类
多层神经网络,从零开始——(十三)、关于并行计算的简单探讨

你可能感兴趣的:(多层神经网络,从零开始——(七)、损失函数)