一直在等待,一直会等待 TensorFlow常见API--2

tf.nn.softmax_cross_entropy_with_logits(_sentinel=None,labels=None,logits=None,dim=-1,name=None)

        计算 softmax(logits) 和 softmax(labels) 之间的交叉熵
_sentinel: 用于防止位置参数。内部调用,外界调用API不使用。
labels: 真实数据的类别标签。每个类别维度应该是有效的概率分布。比如[batch_size, num_classes]则每一行都是一个类的数据分布
logits: 没有放缩的log概率值
dim: 类维度。默认为-1,这是最后一个维度
name: 操作名称

tf.nn.softmax_cross_entropy_with_logits_v2(_sentinel=None,labels=None,logits=None,dim=-1,name=None)

        计算 softmax(logits) 和 softmax(labels) 之间的交叉熵
_sentinel: 用于防止位置参数。内部调用,外界调用API不使用。
labels: 真实数据的类别标签。每个类别维度应该是有效的概率分布。比如[batch_size, num_classes]则每一行都是一个类的数据分布
logits: 没有放缩的log概率值
dim: 类维度。默认为-1,这是最后一个维度
name: 操作名称
两者之间的区别:
         softmax_cross_entropy_with_logits进行反向传播时,只对logits进行反向传播,labels保持不变
         tf.nn.softmax_cross_entropy_with_logits_v2同时对logits和labels进行反向传播。如果将labels传入的tensor设置为stop_gradients,和softmax_cross_entropy_with_logits一样
         那么问题来了,一般我们在进行监督学习的时候,labels都是标记好的真值,什么时候会需要改变label?softmax_cross_entropy_with_logits_v2存在的意义是什么?实际上在应用中labels并不一定都是人工手动标注的,有的时候还可能是神经网络生成的,一个实际的例子就是对抗生成网络(GAN)
         区别分析原文:https://blog.csdn.net/u013230189/article/details/82777464?utm_source=copy

tf.nn.sparse_softmax_cross_entropy_with_logits(_sentinel=None,labels=None,logits=None,name=None)

labels: Tensor of shape [d_0, d_1, …, d_{r-1}] ,dtype int32 or int64
logits: Unscaled log probabilities of shape [d_0, d_1, …, d_{r-1}, num_classes] and dtype float16, float32, or float64
        注意:labels是一个数值,这个数值记录着ground truth所在的索引。以[0,0,1,0]为例,这里真值1的索引为2。所以要求labels的输入为数字2(tensor)。一般可以用tf.argmax()来从[0,0,1,0]中取得真值的索引。

import tensorflow as tf
import numpy as np

Truth = np.array([0, 0, 1, 0])
Pred_logits = np.array([3.5, 2.1, 7.89, 4.4])

loss = tf.nn.softmax_cross_entropy_with_logits(labels=Truth,logits=Pred_logits)
loss2 = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Truth,logits=Pred_logits)
loss3 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(Truth),logits=Pred_logits)

with tf.Session() as sess:
    print(sess.run(loss))
    print(sess.run(loss2))
    print(sess.run(loss3))

        代码用例原文:https://blog.csdn.net/tsyccnh/article/details/81069308?utm_source=copy

你可能感兴趣的:(TensorFlow,API,笔记)