TensorFlow中的tf.nn.softmax_cross_entropy_with_logits_v2函数详解

一、函数介绍

函数形式:tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits)

  • 需要注意的是,此处是最常见的参数形式,即只有labels和logits参数。为了简单起见,这里不对其他参数进行赘述。
  • 其中,logits为神经网络最后一层的输出labels为对应的标签一般为one-hot编码形式即只对真实标签对应的logits分量进行交叉熵损失的计算

事实上,labels也不一定非得是one-hot编码形式;实际上,如果labels和logits维度大小相同,不管labels是不是one-hot编码,代码都可以正常运行,只是没什么意义罢了。一般情况下,建议大家都用one-hot编码形式的labels进行计算。

二、实例讲解

下面,采用非one-hot编码形式的labels对tf.nn.softmax_cross_entropy_with_logits_v2函数进行过程分析,如果非one-hot编码形式的计算过程都能理解,那么one-hot编码形式的计算过程就更清楚了。

简单来说,该函数内部会首先对logits分量进行softmax处理,之后再按照交叉熵损失的公式进行计算。

# 参考链接:
# https://blog.csdn.net/sdnuwjw/article/details/86086377
# https://blog.csdn.net/Muzi_Water/article/details/81363027
# https://www.cnblogs.com/peixu/p/13201093.html
# https://blog.csdn.net/qq_39208832/article/details/117415229
import math
import torch
import torch.nn as nn
import tensorflow as tf

labels = [[0, 1, 2], [0, 1, 3], [1, 0, 2]]
logits = tf.constant(value=[[0.8, 0.3, 0.4], [0, 0.8, 0.3], [0, 0.7, 0.3]],
                     dtype=tf.float32, shape=[3, 3])

loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits)
# loss_mean = tf.reduce_mean(loss)

with tf.Session() as sess:
    print('labels:',labels)
    print('logits:\n',sess.run(logits))
    print('loss:',sess.run(loss))
    # print('loss_mean:',sess.run(loss_mean))

print("-----------------------------------")
print(nn.Softmax(dim=1)(torch.tensor([[0.8, 0.3, 0.4]])))
print(-math.log(0.2664)-2*math.log(0.2944))
print(nn.Softmax(dim=1)(torch.tensor([[0, 0.8, 0.3]])))
print(-math.log(0.4864)-3*math.log(0.2950))
print(nn.Softmax(dim=1)(torch.tensor([[0, 0.7, 0.3]])))
print(-math.log(0.2292)-2*math.log(0.3093))

TensorFlow中的tf.nn.softmax_cross_entropy_with_logits_v2函数详解_第1张图片

注意:

1.tf.nn.softmax_cross_entropy_with_logits和tf.nn.softmax_cross_entropy_with_logits_v2用法没什么区别,至少对于上面这个例子是这样的。大家可以把上述例子中的函数改为tf.nn.softmax_cross_entropy_with_logits试试,可以看到,输出结果并不会发生什么变化。
2.一般,在用该函数进行交叉熵损失求解时,会进行如下设定:
cross_entropy_loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=labels)),即对一个batch中的样本的损失求均值。事实上,Pytorch中的交叉熵损失计算公式nn.CrossEntropyLoss()默认对batch内的样本损失进行求均值处理;而TensorFlow中的tf.nn.softmax_cross_entropy_with_logits_v2并没有默认对batch内的样本的损失进行均值处理,所以需要手动进行均值处理。

你可能感兴趣的:(深度学习,tensorflow,python,深度学习)