pytorch中的cross_entropy函数

        cross_entropy函数是pytorch中计算交叉熵的函数。根据源码分析,输入主要包括两部分,一个是input,是维度为(batch_size,class)的矩阵,class表示分类的数量,这个就表示模型输出的预测结果;另一个是target,是维度为(batch_size)的一维向量,表示每个样本的真实值。输出是交叉熵的值。nn中的CrossEntropyLoss类与此函数的作用相同。


计算过程

        交叉熵是常见的损失函数,之前的文章中已经详细介绍了交叉熵的公式由来(交叉熵详解),公式如下:

L=-[y*log\hat{y}+(1-y)*log(1-\hat{y})]

        如果用在多分类问题中当做损失函数的话,一般会这样写:

L=-\sum_{i=1}^{n}y\cdot log_{2}\hat{y}

        其中y是真实分类,是一个标签值;\hat{y}是模型预测结果,包含了属于每种标签的概率(此时这几个概率相加还不等于1)。在上面说了函数的输入分别是input和target,那么y就对应target这个向量,\hat{y}就对应input这个矩阵。但是现在就出现了两个问题:

         1、target就是一个标签值,无法与input直接进行运算,那么我们就对这个target先进行one-hot编码,使二者的维度相同。 比如target的值是[3],共有5个类,那么转换为one-hot编码之后就是:[0,0,0,1,0]。

        2、input中的预测值不能直接代表概率,而且这几个值相加不为1,这时就进行一个softmax操作,让模型的输出值满足上面两个条件,对输出结果的归一化公式如下:

\hat{y}=P(\hat{y}=i|x)=\frac{e^{input_{[i]}}}{\sum_{j=1}^{n}e^{input_{[j]}}}

        将\hat{y}带入到上面的损失函数公式中进行推导:

L=-\sum_{i=1}^{n}y\cdot log_{2}\hat{y}=-\sum_{i=1}^{n}y\cdot \frac{e^{input_{[i]}}}{\sum_{j=1}^{n}e^{input_{[j]}}}=-\sum_{i=1}^{n}y(input[i]-log_{2}\sum_{j=1}^{n}e^{input_{[j]}})

        在第一点中我们已经将y转换为了[0,0,0,1,0]这样的编码,可见式子中只有target那一项的损失值需要计算,其他与0相乘就都消掉了,所以式子中最外层那个连加号就可以去掉了。最后式子就简化为:

L=-input[target]+log_{2}\sum_{j=1}^{n}e^{input_{[j]}}


Python实现

        现在已经清楚了cross_entropy这个函数的运算过程,现在用Python来模拟一下,首先构造一个随机的input和target:

output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
print("output:")
print(output)
print("label:")
print(label)

        输出结果如下:

        然后模拟函数的运算过程:

first = -input[0][target[0]]
second = 0
res=0
for j in range(5):
    second += math.exp(input[0][j])
res = (first + math.log(second))
print("自己的计算结果:")
print(res)

         输出结果为:

         然后分别调用cross_entropy函数和CrossEntropyLoss类计算loss的结果:

criterion = nn.CrossEntropyLoss()
loss = criterion(input, target)
loss2 = torch.nn.functional.cross_entropy(input=input,target=target)
print("cross_entropy函数计算loss的结果:")
print(loss)
print("CrossEntropyLoss类计算loss的结果:")
print(loss2)

        输出结果为:

你可能感兴趣的:(——机器学习——,pytorch,深度学习,机器学习)