TF2.x `class_weight` 导致训练集和验证集损失有很大差异的问题

文章目录

  • TL;DR
  • 项目背景
  • `class_weight` 如何作用到数据集上
  • `sample_weight` 如何作用到损失函数上
  • 为什么训练接 和 验证集的损失会相对于 `class_weight` 变化?

同步一下最新结论. 大佬们认为 验证集应该与数据本身的特性保持一致(不作增强和采样等处理). 相对于模型在新数据上的表现, 训练集和验证集的损失差显得不重要.

当处理 Imbalanced Classification 问题时, 一种常规做法是对不同的类别的损失进行加权. 在 TF2.x 中, 这可以通过设置 Model.fit()class_weight 参数来实现. 常规的 class_weight 可以是 不同类别样本个数的倒数(适当缩放).

但是在实验中发现, 一旦设置了 class_weight 参数, 训练集 和 验证集 的损失会有较大的差异. 那么为什么 class_weight 会导致训练集和验证集的损失差异呢?

TL;DR

先说结论:

在 TF2.x 的实现中, class_weight 只作用于训练集, 而不作用于 验证集, 导致了设置了 class_weight 参数时 训练集 和 验证集 的损失差异较大. 相关讨论已发起 tensorflow/tensorflow:issue#42647 讨论.

项目背景

首先, 简单介绍一下项目情况.

这是一个 五分类问题, 各个类别间的样本量差距较大, 分别为:

 {
        "样本总量": 479105,
        "class_0": 31815,
        "class_1": 8498,
        "class_2": 4704,
        "class_3": 3370,
        "class_4": 430718
    }

当类别权重为 1:1:1:1:1 时, 训练集 和 验证集 的损失基本一致.

当类别权重为 4:8:10:12:1 (样本量倒数平方根), 损失差距为 1.5 : 0.5.

当样本权重为 15:56:101:142:1 (样本量倒数), 损失差距为 7 : 1.

当类别权重为 227: 3179: 10374: 20212: 1 (样本量倒数平方), 损失差距为 300+ : 5.

Epoch 2/10000
128/128 [==============================] - ETA: 0s - loss: 1.4951 - acc: 0.8798 - acc_f1: 0.2661 - cm: 10485.7598
128/128 [==============================] - 64s 496ms/step - loss: 1.4951 - acc: 0.8798 - acc_f1: 0.2661 - cm: 10485.7598 - val_loss: 0.4941 - val_acc: 0.8888 - val_acc_f1: 0.2749 - val_cm: 2621.4399
Epoch 3/10000
128/128 [==============================] - ETA: 0s - loss: 1.4541 - acc: 0.8831 - acc_f1: 0.2735 - cm: 10485.7598
128/128 [==============================] - 63s 494ms/step - loss: 1.4541 - acc: 0.8831 - acc_f1: 0.2735 - cm: 10485.7598 - val_loss: 0.5185 - val_acc: 0.8734 - val_acc_f1: 0.2765 - val_cm: 2621.4399
Epoch 2/10000
128/128 [==============================] - ETA: 0s - loss: 379.6439 - acc: 0.0073 - acc_f1: 0.0077 - cm: 10485.7598
128/128 [==============================] - 63s 489ms/step - loss: 379.6439 - acc: 0.0073 - acc_f1: 0.0077 - cm: 10485.7598 - val_loss: 4.9299 - val_acc: 0.0076 - val_acc_f1: 0.0081 - val_cm: 2621.4399
Epoch 3/10000
128/128 [==============================] - ETA: 0s - loss: 374.9651 - acc: 0.0078 - acc_f1: 0.0099 - cm: 10485.7598
128/128 [==============================] - 63s 492ms/step - loss: 374.9651 - acc: 0.0078 - acc_f1: 0.0099 - cm: 10485.7598 - val_loss: 4.8990 - val_acc: 0.0078 - val_acc_f1: 0.0120 - val_cm: 2621.4399

从实验结果来看, 训练集和验证集的 Loss 不一致, 与 class_weight 有关, 与 batch_size / steps_per_epochs 的参数无关.

为什么 class_weight 会影响到训练集和验证集的 Loss 呢?

class_weight 如何作用到数据集上

首先看 class_weight 是如何作用到数据集上的.

TF2 定义了 DataAdapter(python.keras.engine.data_adapter.DataAdapter) 来适配不同的输入数据格式. Data Adapter 通过 can_handle 方法注册自己能处理的输入数据类型. 比如 DatasetAdapter(python.keras.engine.data_adapter.DatasetAdapter), 其 can_handle 方法内容如下:

@staticmethod
def can_handle(x, y=None):
return (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)) or
        _is_distributed_dataset(x))

表明该数据适配器可以适配 DatasetV1DatasetV2DistributedDatasetDistributedDatasetsFromFunction 数据.

而这些不同的 DataAdapter 统一由 DataHandler(python.keras.engine.data_adapter.DataHandler) 管理. DataHandler 通过 select_data_adapter(x, y) 选择合适的数据适配器. select_data_adapter 本质上是遍历所有数据适配器, 找到能够适配当前输入数据 x, y 的数据适配器.

数据集中的数据最终会通过 DataHandlerenumerate_epochs() 方法暴露给调用者, enumerate_epochs 实质上只是对 DataHandler._dataset 的迭代. DataHandler._dataset 就是 class_weight 作用的对象. DataHandler 首先会调用数据适配器 DataAdapter.get_dataset() 方法获得底层数据集, 然后通过 dataset = dataset.map(_make_class_weight_map_fn(class_weight))class_weight 转换成每个样本的 sample_weight, 其具体转换过程是(代码有省略):

def _make_class_weight_map_fn(class_weight):
  class_ids = list(sorted(class_weight.keys()))
  class_weight_tensor = ops.convert_to_tensor_v2(
      [class_weight[int(c)] for c in class_ids])

  def _class_weights_map_fn(*data):
    x, y, sw = unpack_x_y_sample_weight(data)

    y_classes = smart_cond.smart_cond(
        y.shape.rank == 2 and backend.shape(y)[1] > 1,
        lambda: backend.argmax(y, axis=1),
        lambda: math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64))

    cw = array_ops.gather_v2(class_weight_tensor, y_classes)
    if sw is not None:
      cw = math_ops.cast(cw, sw.dtype)
      sw, cw = expand_1d((sw, cw))
      # `class_weight` and `sample_weight` are multiplicative.
      sw = sw * cw
    else:
      sw = cw

    return x, y, sw

  return _class_weights_map_fn

需要说明的是,

  • 当原始输入数据是 Dataset 时, 是不支持 sample_weight 参数的 (参考 python.keras.engine.data_adapter.DatasetAdapter._validate_args)
  • class_weight 参数是不支持 多输出模型(Multi-Output Model) 的, 也不支持标签的秩大于 2 的数据集, 即只支持 [batch_size][batch_size, num_classes] 两种格式 (参考 python.keras.engine.data_adapter._make_class_weight_map_fn)
  • 如果数据集同时指定了 sample_weightclass_weight, 则两则会共同作用于数据

sample_weight 如何作用到损失函数上

其次, 看看 sample_weight 是如何作用到 损失函数 上的.

模型的损失函数是通过 compile() 方法设置的, Model.compiled_loss 是模型的 loss 容器, 而 Model.loss 则是模型的原始损失. 训练集的损失最终是在 train_step() 方法里计算的, 计算逻辑是:

data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
with backprop.GradientTape() as tape:
  y_pred = self(x, training=True)
  loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)

对应的验证集损失则是在 test_step() 方法里计算的, 其计算逻辑是:

data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
y_pred = self(x, training=False)
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)

由上的对比可以发现, 训练集和验证集的损失计算方式是一致的.

事实上, sample_weight 并不会直接作用到 compile() 提供的具体的 loss (如 CrossEntropy 等) 上, 而是由 python.keras.losses.Loss.__call__() 对 具体 Loss 返回的样本损失(sample losses) 按照 sample_weight 进行加权, 具体可以参考 Loss.__call__() 源码:

  def __call__(self, y_true, y_pred, sample_weight=None):
    graph_ctx = tf_utils.graph_context_for_symbolic_tensors(
        y_true, y_pred, sample_weight)
    with K.name_scope(self._name_scope), graph_ctx:
      ag_call = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
      
      # 获取单个样本的损失
      losses = ag_call(y_true, y_pred)

      # 按照 sample_weight 进行加权
      return losses_utils.compute_weighted_loss(
          losses, sample_weight, reduction=self._get_reduction())

为什么训练接 和 验证集的损失会相对于 class_weight 变化?

继续跟进 losses_utils.compute_weighted_loss() 的源码, 可以发现 sample_weight 加权的是这样做的 weighted_losses = math_ops.multiply(losses, sample_weight), 对于加权后的损失是 loss = reduce_weighted_loss(weighted_losses, reduction) 进行均值的, 具体代码是:

def reduce_weighted_loss(weighted_losses,
                         reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
  """Reduces the individual weighted loss measurements."""
  if reduction == ReductionV2.NONE:
    loss = weighted_losses
  else:
    loss = math_ops.reduce_sum(weighted_losses)
    if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:
      loss = _safe_mean(loss, _num_elements(weighted_losses))
  return loss

其中分母 _num_elements(weighted_losses) 实质上是 tf.size(weighted_losses), 即 weighted_losses 的元素个数. 显然, 当 sample_weight 变化时, loss 的绝对值会跟着变化. 感觉这里最好使用 math_ops.reduce_sum(sample_weight) 来规范化 loss 比较好.

但实际上, 造成训练集 和 验证集损失不同的原因却不在这里, 通过跟进 Model.fit 代码可以发现, class_weight 压根没有作用于验证集, 代码如下:

# for training dataset
data_handler = data_adapter.DataHandler(
  x=x,
  y=y,
  sample_weight=sample_weight,
  batch_size=batch_size,
  steps_per_epoch=steps_per_epoch,
  initial_epoch=initial_epoch,
  epochs=epochs,
  shuffle=shuffle,
  class_weight=class_weight,
  max_queue_size=max_queue_size,
  workers=workers,
  use_multiprocessing=use_multiprocessing,
  model=self,
  steps_per_execution=self._steps_per_execution)

# for validation dataset
self._eval_data_handler = data_adapter.DataHandler(
    x=val_x,
    y=val_y,
    sample_weight=val_sample_weight,
    batch_size=validation_batch_size or batch_size,
    steps_per_epoch=validation_steps,
    initial_epoch=0,
    epochs=1,
    max_queue_size=max_queue_size,
    workers=workers,
    use_multiprocessing=use_multiprocessing,
    model=self,
    steps_per_execution=self._steps_per_execution)

因此 class_weight 对验证集是无影响的, 自然也就对验证集的损失没有影响. 既然 sample_weight 可以作用于验证集, 为什么 class_weight 不可以呢?

那么 class_weight 应该作用于验证集吗? 至少从保持 训练集 和 验证集 的损失的一致性这个点出发, class_weight 是应该作用于验证集的.

相关讨论已发起 tensorflow/tensorflow:issue#42647, 看看大佬们怎么看待这个问题.

你可能感兴趣的:(TensorFlow,2.x,源码学习,tensorflow,深度学习)