坑爹啊,tf.nn.softmax_cross_entropy_with_logits坑了我好久

1.tf.nn.softmax_cross_entropy_with_logits作用

这里先给出两个函数:

tf.nn.softmax_cross_entropy_with_logits(_sentinel=None,  # pylint: disable=invalid-name
                                      labels=None, logits=None,
                                      dim=-1, name=None)

tf.nn.spare_softmax_cross_entropy_with_logits(_sentinel=None,  # pylint: disable=invalid-name
                                             labels=None, logits=None,
                                             name=None)

先说明,tf.nn.softmax_cross_entropy_with_logits这个函数的作用是什么:
网上有很多关于这个函数的解析,这里盗用一下人家的图,这是它的计算公式交叉熵损失计算公式
tf.nn.spare_softmax_cross_entropy_with_logits本质与其一致。
区别在于:

  • tf.nn.spare_softmax_cross_entropy_with_logits的参数labels不能是one-hot编码
  • tf.nn.softmax_cross_entropy_with_logits的参数labels可以是one-hot编码

此外,每次使用tf.nn.softmax_cross_entropy_with_logits,tensorflow都会提出这个方法即将被摈弃,我原以为摈弃就摈弃,功能应该是没问题的。这恰恰是我天真的地方。

大概测试了一下,两个函数的功能:

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

label_num = [0,1]
label = [[1,0], [0,1]]                              # [B,N]
logist = [[-0.00161515, 0.0061736 ], [ 0.00118857, 0.00383149]]
label_num_t = tf.convert_to_tensor(label_num)
label_t = tf.to_float(tf.convert_to_tensor(label))
logist_t = tf.convert_to_tensor(logist)
logist_t_softmax = tf.nn.softmax(logist_t)
loss_1 = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=label_t, logits=logist_t)
)

loss_5 = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_num_t, logits=logist_t)
)
print(sess.run(loss_1))
print(sess.run(loss_5))

发现输出为

0.69443786
0.69443786

嗯,一模一样啊,那我就放心用了。
意外发生了。
当类别数量达到100的时候,根据公式,理论上1个样本的一个loss值不会超过-1*log(1/100)=4.60517,而我运行结果却是5.90。就是这个函数的奇奇怪怪的结果,导致我训练过程中loss不减反增。
因此,我果断改用tf.nn.sparse_softmax_cross_entropy_with_logits
下面给出将one-hot编码,改为普通编码的方式:

label = tf.argmax(label, axis=1)  # label.shape从[B,N]->[B]
loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=logit)
        )

到此,就可以愉快地继续训练了~
希望会对大家有帮助,不要继续踩坑。

这里,感谢这位网友的图

https://www.jianshu.com/p/cf235861311b

你可能感兴趣的:(tensorflow,推荐算法)