损失函数与反向传播

损失函数定义与作用

损失函数(loss function)在深度学习领域是用来计算搭建模型预测的输出值和真实值之间的误差。
1.损失函数越小越好
2.计算实际输出与目标之间的差距
3.为更新输出提供依据(反向传播)

常见的损失函数

回归常见的损失函数有:均方差(Mean Squared Error,MSE)、平均绝对误差(Mean Absolute Error Loss,MAE)、Huber Loss是一种将MSE与MAE结合起来,取两者优点的损失函数,也被称作Smooth Mean Absolute Error Loss 、分位数损失(Quantile Loss)损失函数。

分类常见的损失函数有:交叉熵损失(Cross Entropy Loss)、合页损失(Hinge Loss)、0/1损失函数、指数损失、对数损失/对数似然损失(Log-likelihood Loss)

平均绝对误差损失(Mean Absolute Error Loss)

损失函数与反向传播_第1张图片
例如 x = [1,2,3],y = [1,2,5]
如果reduction = mean ,结果就是 ((1-1)+(2-2)+(5-3))/3 = 2/3
如果reduction = sum ,结果就是((1-1)+(2-2)+(5-3)) = 2
代码示例:

import torch
from torch.nn import L1Loss

input  = torch.tensor([1,2,3],dtype=float)
target = torch.tensor([1,2,5],dtype=float)
input = torch.reshape(input,(1,1,1,3))
target = torch.reshape(target,(1,1,1,3))
loss = L1Loss()
output = loss(input,target)
print(output)

MSELoss

损失函数与反向传播_第2张图片
和平均绝对误差损失(Mean Absolute Error Loss)差不多,但是加了一个平方

MAE与MSE的区别:

MSE比MAE能够更快收敛:当使用梯度下降算法时,MSE损失的梯度为,而MAE损失的梯度为正负1。所以。MSE的梯度会随着误差大小发生变化,而MAE的梯度一直保持为1,这不利于模型的训练
MAE对异常点更加鲁棒:从损失函数上看,MSE对误差平方化,使得异常点的误差过大;从两个损失函数的假设上看,MSE假设了误差服从高斯分布,MAE假设了误差服从拉普拉斯分布,拉普拉斯分布本身对于异常点更加鲁棒

交叉熵损失函数 (Cross-entropy loss function)

这个有点复杂,没看懂啊
在看参考文章时发现一篇文章很好,在这里,交叉熵损失函数(cross-entropy loss function)
还有很多损失函数,学基础我先了解这些,还有很多,在这里

使用神经网络利用损失函数计算差值

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10("dataset2",train=False,transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset,batch_size=1)

class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()
        self.conv1 = Conv2d(3,32,5,padding=2)
        self.maxpool1 = MaxPool2d(kernel_size=2)
        self.conv2 = Conv2d(32,32,5,padding=2)
        self.maxpool2 = MaxPool2d(2)
        self.conv3 = Conv2d(32,64,5,padding=2)
        self.maxpool3 = MaxPool2d(kernel_size=2)
        self.flatten = Flatten()
        self.linear1 = Linear(1024,64)
        self.linear2 = Linear(64,10)

    def forward(self,x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x;
loss = CrossEntropyLoss()
test = Test()
for data in dataloader:
    imgs,target = data
    output = test(imgs)
    result_loss = loss(output,target)
    result_loss.backward()
    print(result_loss)

输出:

tensor(2.2068, grad_fn=<NllLossBackward0>)
tensor(2.2500, grad_fn=<NllLossBackward0>)
tensor(2.2562, grad_fn=<NllLossBackward0>)
tensor(2.4017, grad_fn=<NllLossBackward0>)
tensor(2.4001, grad_fn=<NllLossBackward0>)
tensor(2.3810, grad_fn=<NllLossBackward0>)
tensor(2.3225, grad_fn=<NllLossBackward0>)
tensor(2.3851, grad_fn=<NllLossBackward0>)
tensor(2.2274, grad_fn=<NllLossBackward0>)
tensor(2.3469, grad_fn=<NllLossBackward0>)
tensor(2.4058, grad_fn=<NllLossBackward0>)
tensor(2.1119, grad_fn=<NllLossBackward0>)
tensor(2.3562, grad_fn=<NllLossBackward0>)
tensor(2.2470, grad_fn=<NllLossBackward0>)
tensor(2.1172, grad_fn=<NllLossBackward0>)
tensor(2.2596, grad_fn=<NllLossBackward0>)
tensor(2.3484, grad_fn=<NllLossBackward0>)
tensor(2.2497, grad_fn=<NllLossBackward0>)
tensor(2.2346, grad_fn=<NllLossBackward0>)
tensor(2.3989, grad_fn=<NllLossBackward0>)
tensor(2.2627, grad_fn=<NllLossBackward0>)
tensor(2.4221, grad_fn=<NllLossBackward0>)
tensor(2.3888, grad_fn=<NllLossBackward0>)
tensor(2.1550, grad_fn=<NllLossBackward0>)
tensor(2.3529, grad_fn=<NllLossBackward0>)
tensor(2.3224, grad_fn=<NllLossBackward0>)
tensor(2.3853, grad_fn=<NllLossBackward0>)
tensor(2.3900, grad_fn=<NllLossBackward0>)
tensor(2.1328, grad_fn=<NllLossBackward0>)
tensor(2.3807, grad_fn=<NllLossBackward0>)
tensor(2.3733, grad_fn=<NllLossBackward0>)
tensor(2.3658, grad_fn=<NllLossBackward0>)
tensor(2.3882, grad_fn=<NllLossBackward0>)
......

你可能感兴趣的:(PyTorch,pytorch,深度学习,python)