Tensorflow 中网络准确度不变,权重初始化NaN问题

最近刚刚接触深度学习,由于项目涉及到一些移动端开发的问题,也听了一些朋友的建议,最后决定选择tensorflow作为研究深度学习的平台。这两天照着tflearn官网的VGGNet的demo,用tensorflow实现了VGGNet,然而在用17flowers训练集进行训练的时候,发现不管迭代多少次,准确率和loss函数始终维持在相对不变的值,也就是网络不收敛。一开始很懵逼,毕竟是照着官网的demo做的,怎么会出现这种情况?首先想到的办法就是将中间值打出来,比如:

for i in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5})

    print("Loss:", sess.run(cross_entropy,feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5}))
    print(cross_entropy)
    if i % 50 == 0:
        print(compute_accuracy(
            mnist.test.images, mnist.test.labels))

Loss是互熵损失的输出值,但结果显示它的值是Nan。于是追根溯源,又打出了Weights和bias的值,发现一样也是Nan。然后我就去google这个问题,发现其实还是有不少人遇到了这个问题的。在stackoverflow上,其中一个人的解释是这样的:

Actually, it turned out to be something stupid. I'm posting this in case anyone else would run into a similar error.

cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
is actually a horrible way of computing the cross-entropy. In some samples, certain classes could be excluded with certainty after a while, resulting in y_conv=0 for that sample. That's normally not a problem since you're not interested in those, but in the way cross_entropy is written there, it yields 0*log(0) for that particular sample/class. Hence the NaN.

Replacing it with

cross_entropy = -tf.reduce_sum(y_*tf.log(tf.clip_by_value(y_conv,1e-10,1.0)))
solved all my problems.

于是试了一下,果然没有Nan了,但是这个朋友的解释,我似懂非懂,大概意思就是有一些样本通过正向传递输出到最外层的时候输出值变为了0,于是log(0)会导致结果显示为Nan。随后,下面还有一个网友给出解释,说clipping的办法并不是很好,因为反向传播的时候如果达到了阈值,会阻止梯度的改变。所以直接在log函数中添加了一个小常量,为了不让预测值为0。

cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv + 1e-10))

虽然问题解决了,但理解的并不透彻。如果有朋友理解的很透彻,也麻烦在评论里指点一下!非常感谢!

------------------------------------------------------------------------分割线-------------------------------------------------------------------------
上次发现了这个问题之后,导师也非常感兴趣,他很好奇,只是添加这么一个极小值就可以让网络继续训练下去实在是不可思议。但归根结底,原因还是预测值的部分出现了数学上无意义的值。所以导师就想到了在求互熵损失时,由于是对e求指数,有可能会导致运算溢出。这个问题仍待解决


附上stackoverflow中的解答:http://stackoverflow.com/questions/33712178/tensorflow-nan-bug

你可能感兴趣的:(深度学习)