tf.clip_by_value()防止梯度爆炸

y_pred = tf.clip_by_value(y_pred, 10e-8, 1.-10e-8)

后接log防止出现log0。

import tensorflow as tf

v=tf.constant(([1.0,2.0,3.0],[4.0,5.0,6.0]))

with tf.Session() as sess:
    print(tf.clip_by_value(v,2.5,4.5).eval())

输出:

[[2.5 2.5 3. ]
 [4.  4.5 4.5]]

你可能感兴趣的:(T型牌坊)