Pytorch二元交叉熵损失函数种类及接口

之前学了很久的Tensorflow,最近也在研究Pytorch,对损失函数的部分做以下的总结。本文只介绍二分类的二元交叉熵损失。这里不考虑batchsize的情况。只聚焦于公式本身和接口。

在人工智能模型应用的层面,根据目标的不同,大致可分为两类:1 回归, 2 分类

一、回归模型

最常用的损失函数:MSE(mean squared error)

公式如下:     

 ym是真实的函数值,是你的模型预测的函数值。M代表 预测的样本个数。

二、 分类模型

首先,需要明确,不同的Pytorch版本 能直接调用的函数不尽相同,在这里 我着重详列1.7.0版本的所有可使用的二元交叉熵损失函数,以及它们的区别。大家如果想进一步学习,网址如下:

torch.nn.functional — PyTorch 1.7.0 documentation

最常用的损失函数:BCE(binary cross entropy)

公式如下:

 yi是真实标签,xi 是 模型预测的概率值。

1、torch.nn.functional.binary_cross_entropy_with_logits(input, target)

这个二元交叉熵损失的会默认对input里面的每个数据进行sigmoid处理。

例子:

input = [0.3923, -0.2236, -0.3195]

target = [0, 1, 0]

loss = 0.7752

下图是我自己计算的结果,下下图是调用的Pytorch官网的函数。

Pytorch二元交叉熵损失函数种类及接口_第1张图片

 Pytorch二元交叉熵损失函数种类及接口_第2张图片

 

先写到这,后面的函数再补充把。

你可能感兴趣的:(pytorch,深度学习,机器学习,tensorflow,人工智能)