损失函数是用来评价模型的预测值和真实值不一样的程度。损失函数越好,通常模型的性能也越好。
损失函数分为经验风险损失函数和结构风险损失函数:
本质上是一种对数似然函数,可用于二分类和多分类任务中。
(1)二分类问题中的loss函数(输入数据是softmax或者sigmoid函数的输出):
l o s s = − 1 n ∑ x [ y ln a + ( 1 − y ) ln ( 1 − a ) ] loss = -\frac{1}{n}\sum_{x}[y\ln a + (1-y) \ln (1-a)] loss=−n1x∑[ylna+(1−y)ln(1−a)]
其中, n n n 表示样本数量; y y y 表示样本标签,如负标签为0,正标签为1; a a a 表示模型输出(经过softmax或sigmoid),可以简单看做是预测为正样本的概率值,如0.71。
代码实现如下:
import torch
import torch.nn.functional as F
y = torch.randint(0,2,size=(10,1)).to(torch.float)
p = torch.rand((10,1))
print(y, p)
得到的数据集如下:
(tensor([[0.],
[0.],
[0.],
[1.],
[1.],
[1.],
[0.],
[0.],
[1.],
[0.]]),
tensor([[0.3630],
[0.6427],
[0.5496],
[0.5975],
[0.5041],
[0.8786],
[0.8593],
[0.1009],
[0.5567],
[0.0720]]))
def cross_entropy(y:torch.Tensor, p:torch.Tensor):
return -torch.sum( y*torch.log(p) + (1-y)*torch.log(1-p) ) / len(y)
print( cross_entropy(y, p) )
得到的结果为:
tensor(0.6335)
pytorch
自带的二分类交叉熵损失函数也做了计算:F.binary_cross_entropy(p, y.to(torch.float))
得到的结果也是 tensor(0.6335)
。
(2)多分类问题中的 loss 函数(输入数据也是经过softmax 或 sigmoid 函数的输出):
l o s s = − 1 n ∑ i y i ln a i loss = -\frac{1}{n}\sum_{i}y_i\ln a_i loss=−n1i∑yilnai
n n n 表示样本数量,当样本属于类别 i i i 时, y i = 1 y_i = 1 yi=1,否则 y i = 0 y_i = 0 yi=0; a i a_i ai 表示第 y i y_i yi 的预测概率。
代码实现如下:
a
,这里没有做softmax计算,原因后边会解释):a = torch.randn(10, 3, requires_grad=True)
y = torch.randint(3, (10,), dtype=torch.int64)
print(a, y)
得到数据如下:
tensor([[ 1.0158, -0.5252, 0.4688],
[ 1.1140, -0.3159, 2.0729],
[-0.3154, -0.7816, 0.7599],
[-0.6543, -0.5513, -0.4667],
[ 1.2306, 1.7470, -0.1128],
[ 0.7115, 0.6078, -1.3790],
[-0.9446, 0.7456, -0.2455],
[-0.6758, 0.6414, 0.0501],
[ 0.0048, -0.2063, 0.6877],
[ 0.4730, -0.7138, -0.6765]], requires_grad=True) tensor([1, 2, 0, 0, 2, 2, 0, 1, 0, 2])
def cross_entropy(a, y):
a = F.softmax(a)
return -torch.sum(F.one_hot(y) * torch.log(a+0.000001)) / len(y)
cross_entropy(a, y)
得到的结果是tensor(-1.6143, grad_fn=
。
我们首先将标签进行one-hot编码,这样可以直接与分类概率相乘。0.000001只是为了防止log计算出现bug。
我们将softmax计算放到了函数里。
pytorch
提供的API进行计算:F.cross_entropy(a, y)
得到的结果是tensor(-1.6143, grad_fn=
。
python提供的API接口,内部进行了softmax 计算,所以给到该API的输入不要进行softmax计算!!!!