Tensorflow中tf.keras.metrics.MeanIoU在shape不一致错误

TensorFlow版本:在2.4 和 2.5上这样改就可以(已测试)
还有其他版本好像是调用 call 方法实现的IoU,所以需要对应需要修改 call 函数

Tensorflow中tf.keras.metrics.MeanIoU在预测返回值为one-hot编码的情况下使用IoU

class MeanIoU(tf.keras.metrics.MeanIoU): 
    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.argmax(y_pred, axis=-1)
        return super().update_state(y_true, y_pred, sample_weight=sample_weight)

model.compile(optimizer='adam',
             loss='sparse_categorical_crossentropy',
             metrics=['acc', MeanIoU(num_classes=34)])

修改原因:

IoU 算法计算在 tf.keras.metrics.MeanIoU.update_state 函数中执行计算(计算正确率),
我们的网络最后输出的的预测值是one-hot-encode的(例如使用最后一层使用 ‘softmax’ 激活的网络)。
同时我们的Label数据集采用 label-encode,此时输入进入 update_state 的 y_true 和 y_pred 参数在shape上不一致,就会导致错误。 所以进行此操作,将y_pred转化为合适的格式。

你可能感兴趣的:(tensorflow)