在使用tf.nn.weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight)时,如何传递参数logits?

在使用tf.nn.weighted_cross_entropy_with_logits_v2(labels, logits, pos_weight)时,如何传递参数logits?

如果网络的输出是preds,那么不能直接将preds传递给logits,而应该使logits=K.log(preds/(1-preds)),其中K来自于import tensorflow.keras.backend as K

其原因在于该函数的公式表达式为:labels * -log(sigmoid(logits)) * pos_weight + (1 - labels) * -log(1 - sigmoid(logits))

而非交叉熵公式:labels * -log(preds) * pos_weight + (1 - labels) * -log(1 - preds)

根据https://www.jianshu.com/p/31c7fe00d9de的推导,preds = sigmoid(log(preds/(1-preds)))

因此,在该函数的logits参数传递时,需要传入logits=K.log(preds/(1-preds))才能得到交叉熵的公式:labels * -log(preds) * pos_weight + (1 - labels) * -log(1 - preds)

 

你可能感兴趣的:(tensorflow,python,深度学习,tensorflow)