Problem:
When using a loss function from the Keras API, the gradients could not be backpropagated.
Code:
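import tensorflow as tf
from tensorflow.keras.losses import categorical_crossentropy

def train_generator(x, y, z, eps, dcgan, siamese_model, loss=None):
    # persistent=True keeps the tape usable for more than one gradient() call
    with tf.GradientTape(persistent=True) as t:
        # Generate fake samples conditioned on the labels y
        fake_x = dcgan.generator([z, y])
        # Adversarial generator loss from the discriminator's scores
        loss_G = -tf.reduce_mean(dcgan.discriminator(fake_x))
        # Auxiliary classification loss. Note: aux_model is not a parameter of this
        # function (siamese_model is unused); it is assumed to exist in the enclosing scope.
        preds = aux_model(fake_x)
        aux_mean = categorical_crossentropy(y, preds)
        aux_loss = tf.reduce_mean(aux_mean)
        total_loss = aux_loss + loss_G
    gradient_g = t.gradient(total_loss, dcgan.generator.trainable_variables)
    dcgan.optimizer_G.apply_gradients(zip(gradient_g, dcgan.generator.trainable_variables))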
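In practice, "the gradients cannot be backpropagated" usually means that `t.gradient` returns `None` for some of the variables. A small helper like the hypothetical `report_missing_gradients` below (a sketch, not part of any library) lists which variables ended up disconnected from the loss; the toy variables `used` and `unused` are also just for illustration:

```python
import tensorflow as tf

def report_missing_gradients(gradients, variables):
    """Return the names of variables whose gradient is None.

    A None gradient means the tape recorded no differentiable path
    from the loss back to that variable.
    """
    return [v.name for g, v in zip(gradients, variables) if g is None]

# Toy example: `used` contributes to the loss, `unused` does not.
used = tf.Variable(2.0, name="used")
unused = tf.Variable(3.0, name="unused")

with tf.GradientTape() as tape:
    loss = used * used

grads = tape.gradient(loss, [used, unused])
missing = report_missing_gradients(grads, [used, unused])
print(missing)
```

Inside `train_generator` the same check would be `report_missing_gradients(gradient_g, dcgan.generator.trainable_variables)`, run before `apply_gradients`.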
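Suspected cause:
The Keras interface sometimes preprocesses the data before calling the TensorFlow backend, which could break the gradient chain so that gradient descent via chain-rule differentiation no longer works.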
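A suspicion like this can be tested directly: compute the loss inside a `GradientTape` and check whether the tape returns a gradient for the variable that produced the predictions. The sketch below (the helpers `has_gradient` and `numpy_detour_loss` are hypothetical) compares `tf.keras.losses.categorical_crossentropy` with a loss that deliberately leaves TensorFlow through NumPy, which is a known way to break the gradient chain:

```python
import numpy as np
import tensorflow as tf

def has_gradient(loss_fn):
    """Return True if the tape yields a gradient from loss_fn back to the logits."""
    logits = tf.Variable([[1.0, 2.0, 0.5]])
    y_true = tf.constant([[0.0, 1.0, 0.0]])
    with tf.GradientTape() as tape:
        y_pred = tf.nn.softmax(logits)
        loss = tf.reduce_mean(loss_fn(y_true, y_pred))
    return tape.gradient(loss, logits) is not None

def numpy_detour_loss(y_true, y_pred):
    # Converting to NumPy hides the computation from the tape,
    # so the gradient chain really is cut here.
    p = y_pred.numpy()
    return tf.constant(-np.sum(y_true.numpy() * np.log(p), axis=-1))

keras_ok = has_gradient(tf.keras.losses.categorical_crossentropy)
numpy_ok = has_gradient(numpy_detour_loss)
print(keras_ok, numpy_ok)
```

Running this check in your own environment shows whether the loss function itself is what drops the gradient.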
Let's look at how categorical_crossentropy is defined in the Keras source:
root@Ie1c58c4ee0020126c:~# find / -iname keras
/usr/local/lib/python3.7/dist-packages/keras
/usr/local/lib/python3.7/dist-packages/tensorflow_core/contrib/keras
/usr/local/lib/python3.7/dist-packages/tensorflow_core/contrib/keras/api/keras
/usr/local/lib/python3.7/dist-packages/tensorflow_core/python/keras
/usr/local/lib/python3.7/dist-packages/tensorflow_core/python/keras/api/_v1/keras
/usr/local/lib/python3.7/dist-packages/tensorflow_core/python/keras/api/_v2/keras
/usr/local/lib/python3.7/dist-packages/tensorflow_core/python/keras/api/keras
vim /usr/local/lib/python3.7/dist-packages/tensorflow_core/python/keras/losses.py
losses.py
def categorical_crossentropy(y_true,
                             y_pred,
                             from_logits=False,
                             label_smoothing=0):
    """Computes the categorical crossentropy loss.
    Args:
        y_true: tensor of true targets.
        y_pred: tensor of predicted targets.
        from_logits: Whether `y_pred` is expected to be a logits tensor. By default,
            we assume that `y_pred` encodes a probability distribution.
        label_smoothing: Float in [0, 1]. If > `0` then smooth the labels.
    Returns:
        Categorical crossentropy loss value.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    y_true = math_ops.cast(y_true, y_pred.dtype)
    label_smoothing = ops.convert_to_tensor(label_smoothing, dtype=K.floatx())

    def _smooth_labels():
        num_classes = math_ops.cast(array_ops.shape(y_true)[1], y_pred.dtype)
        return y_true * (1.0 - label_smoothing) + (label_smoothing / num_classes)

    y_true = smart_cond.smart_cond(label_smoothing,
                                   _smooth_labels, lambda: y_true)
    return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
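As you can see, the Keras wrapper preprocesses the data (converting y_pred to a tensor, casting y_true to the same dtype, and optionally smoothing the labels) and only then calls categorical_crossentropy from the TensorFlow backend (K). This preprocessing is presumably what breaks the gradient chain, so the gradients can no longer be updated through chain-rule differentiation and backpropagation.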
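For reference, the `_smooth_labels` branch in the excerpt computes `y_true * (1 - label_smoothing) + label_smoothing / num_classes`, and with the default `label_smoothing=0` the `lambda: y_true` branch returns the labels unchanged. A plain-Python sketch of the same formula for a single one-hot row (the name `smooth_labels` is hypothetical), handy for checking numbers by hand:

```python
def smooth_labels(y_true, label_smoothing):
    """Apply the label-smoothing formula from the Keras excerpt to one one-hot row."""
    num_classes = len(y_true)
    return [y * (1.0 - label_smoothing) + label_smoothing / num_classes for y in y_true]

print(smooth_labels([0.0, 1.0, 0.0], 0.0))  # default: labels unchanged
print(smooth_labels([0.0, 1.0, 0.0], 0.3))  # mass moved from the true class to the others
```

The smoothed row still sums to 1; only the label targets change, while y_pred passes through untouched.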
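Solution:
Implement the loss function yourself, or find its source in Keras's losses.py and call the TensorFlow backend function it wraps (here K.categorical_crossentropy) directly as the loss function.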
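For the first option, writing the loss yourself, the essential point is to use only TensorFlow ops so the tape can record them. A minimal sketch (the name `manual_categorical_crossentropy` is hypothetical) that follows the same steps as the backend implementation for probability inputs: renormalize the predictions, clip them away from 0 and 1, then sum `-y_true * log(y_pred)` over the class axis. The clipping value `eps=1e-7` is assumed to match Keras's default `epsilon()`:

```python
import tensorflow as tf

def manual_categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Per-sample categorical crossentropy for probability inputs, using only TF ops."""
    y_true = tf.cast(y_true, y_pred.dtype)
    # Renormalize so each row of predictions sums to 1.
    y_pred = y_pred / tf.reduce_sum(y_pred, axis=-1, keepdims=True)
    # Keep log() finite by clipping probabilities away from 0 and 1.
    y_pred = tf.clip_by_value(y_pred, eps, 1.0 - eps)
    return -tf.reduce_sum(y_true * tf.math.log(y_pred), axis=-1)

# Usage: gradients flow back to the logits that produced the predictions.
logits = tf.Variable([[2.0, 1.0, 0.1], [0.5, 0.5, 3.0]])
labels = tf.constant([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
with tf.GradientTape() as tape:
    probs = tf.nn.softmax(logits)
    loss = tf.reduce_mean(manual_categorical_crossentropy(labels, probs))
grad = tape.gradient(loss, logits)

# Compare against the Keras loss on the same inputs.
reference = tf.reduce_mean(
    tf.keras.losses.categorical_crossentropy(labels, tf.nn.softmax(logits)))
print(float(loss), float(reference))
```

In `train_generator`, `manual_categorical_crossentropy(y, preds)` would take the place of the library call.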
The final code:
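import tensorflow as tf
from tensorflow.keras import backend as K

def train_generator(x, y, z, eps, dcgan, siamese_model, loss=None):
    # persistent=True keeps the tape usable for more than one gradient() call
    with tf.GradientTape(persistent=True) as t:
        # Generate fake samples conditioned on the labels y
        fake_x = dcgan.generator([z, y])
        # Adversarial generator loss from the discriminator's scores
        loss_G = -tf.reduce_mean(dcgan.discriminator(fake_x))
        # Auxiliary classification loss, now computed with the backend function.
        # aux_model is assumed to exist in the enclosing scope (siamese_model is unused).
        preds = aux_model(fake_x)
        aux_mean = K.categorical_crossentropy(y, preds)
        aux_loss = tf.reduce_mean(aux_mean)
        total_loss = aux_loss + loss_G
    gradient_g = t.gradient(total_loss, dcgan.generator.trainable_variables)
    dcgan.optimizer_G.apply_gradients(zip(gradient_g, dcgan.generator.trainable_variables))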
Problem solved. If you run into the same issue, this may serve as a reference.