当处理 Imbalanced Classification 问题时, 一种常规做法是对不同的类别的损失进行加权. 在 TF2.x 中, 这可以通过设置 Model.fit()
的 class_weight
参数来实现. 常规的 class_weight
可以是 不同类别样本个数的倒数(适当缩放).
但是在实验中发现, 一旦设置了 class_weight
参数, 训练集 和 验证集 的损失会有较大的差异. 那么为什么 class_weight
会导致训练集和验证集的损失差异呢?
先说结论:
在 TF2.x 的实现中, class_weight
只作用于训练集, 而不作用于 验证集, 导致了设置了 class_weight
参数时 训练集 和 验证集 的损失差异较大. 相关讨论已发起 tensorflow/tensorflow:issue#42647 讨论.
首先, 简单介绍一下项目情况.
这是一个 五分类问题, 各个类别间的样本量差距较大, 分别为:
{
"样本总量": 479105,
"class_0": 31815,
"class_1": 8498,
"class_2": 4704,
"class_3": 3370,
"class_4": 430718
}
当类别权重为 1:1:1:1:1
时, 训练集 和 验证集 的损失基本一致.
当类别权重为 4:8:10:12:1
(样本量倒数平方根), 损失差距为 1.5 : 0.5
.
当样本权重为 15:56:101:142:1
(样本量倒数), 损失差距为 7 : 1
.
当类别权重为 227: 3179: 10374: 20212: 1
(样本量倒数平方), 损失差距为 300+ : 5
.
Epoch 2/10000
128/128 [==============================] - ETA: 0s - loss: 1.4951 - acc: 0.8798 - acc_f1: 0.2661 - cm: 10485.7598
128/128 [==============================] - 64s 496ms/step - loss: 1.4951 - acc: 0.8798 - acc_f1: 0.2661 - cm: 10485.7598 - val_loss: 0.4941 - val_acc: 0.8888 - val_acc_f1: 0.2749 - val_cm: 2621.4399
Epoch 3/10000
128/128 [==============================] - ETA: 0s - loss: 1.4541 - acc: 0.8831 - acc_f1: 0.2735 - cm: 10485.7598
128/128 [==============================] - 63s 494ms/step - loss: 1.4541 - acc: 0.8831 - acc_f1: 0.2735 - cm: 10485.7598 - val_loss: 0.5185 - val_acc: 0.8734 - val_acc_f1: 0.2765 - val_cm: 2621.4399
Epoch 2/10000
128/128 [==============================] - ETA: 0s - loss: 379.6439 - acc: 0.0073 - acc_f1: 0.0077 - cm: 10485.7598
128/128 [==============================] - 63s 489ms/step - loss: 379.6439 - acc: 0.0073 - acc_f1: 0.0077 - cm: 10485.7598 - val_loss: 4.9299 - val_acc: 0.0076 - val_acc_f1: 0.0081 - val_cm: 2621.4399
Epoch 3/10000
128/128 [==============================] - ETA: 0s - loss: 374.9651 - acc: 0.0078 - acc_f1: 0.0099 - cm: 10485.7598
128/128 [==============================] - 63s 492ms/step - loss: 374.9651 - acc: 0.0078 - acc_f1: 0.0099 - cm: 10485.7598 - val_loss: 4.8990 - val_acc: 0.0078 - val_acc_f1: 0.0120 - val_cm: 2621.4399
从实验结果来看, 训练集和验证集的 Loss 不一致, 与 class_weight
有关, 与 batch_size
/ steps_per_epochs
的参数无关.
为什么 class_weight
会影响到训练集和验证集的 Loss 呢?
class_weight
如何作用到数据集上首先看 class_weight
是如何作用到数据集上的.
TF2 定义了 DataAdapter(python.keras.engine.data_adapter.DataAdapter)
来适配不同的输入数据格式. Data Adapter 通过 can_handle
方法注册自己能处理的输入数据类型. 比如 DatasetAdapter(python.keras.engine.data_adapter.DatasetAdapter)
, 其 can_handle
方法内容如下:
@staticmethod
def can_handle(x, y=None):
return (isinstance(x, (dataset_ops.DatasetV1, dataset_ops.DatasetV2)) or
_is_distributed_dataset(x))
表明该数据适配器可以适配 DatasetV1
和 DatasetV2
及 DistributedDataset
和 DistributedDatasetsFromFunction
数据.
而这些不同的 DataAdapter 统一由 DataHandler(python.keras.engine.data_adapter.DataHandler)
管理. DataHandler 通过 select_data_adapter(x, y)
选择合适的数据适配器. select_data_adapter
本质上是遍历所有数据适配器, 找到能够适配当前输入数据 x, y
的数据适配器.
数据集中的数据最终会通过 DataHandler
的 enumerate_epochs()
方法暴露给调用者, enumerate_epochs
实质上只是对 DataHandler._dataset
的迭代. DataHandler._dataset
就是 class_weight
作用的对象. DataHandler
首先会调用数据适配器 DataAdapter.get_dataset()
方法获得底层数据集, 然后通过 dataset = dataset.map(_make_class_weight_map_fn(class_weight))
将 class_weight
转换成每个样本的 sample_weight
, 其具体转换过程是(代码有省略):
def _make_class_weight_map_fn(class_weight):
class_ids = list(sorted(class_weight.keys()))
class_weight_tensor = ops.convert_to_tensor_v2(
[class_weight[int(c)] for c in class_ids])
def _class_weights_map_fn(*data):
x, y, sw = unpack_x_y_sample_weight(data)
y_classes = smart_cond.smart_cond(
y.shape.rank == 2 and backend.shape(y)[1] > 1,
lambda: backend.argmax(y, axis=1),
lambda: math_ops.cast(backend.reshape(y, (-1,)), dtypes.int64))
cw = array_ops.gather_v2(class_weight_tensor, y_classes)
if sw is not None:
cw = math_ops.cast(cw, sw.dtype)
sw, cw = expand_1d((sw, cw))
# `class_weight` and `sample_weight` are multiplicative.
sw = sw * cw
else:
sw = cw
return x, y, sw
return _class_weights_map_fn
需要说明的是,
Dataset
时, 是不支持 sample_weight
参数的 (参考 python.keras.engine.data_adapter.DatasetAdapter._validate_args
)class_weight
参数是不支持 多输出模型(Multi-Output Model) 的, 也不支持标签的秩大于 2 的数据集, 即只支持 [batch_size]
和 [batch_size, num_classes]
两种格式 (参考 python.keras.engine.data_adapter._make_class_weight_map_fn
)sample_weight
和 class_weight
, 则两则会共同作用于数据sample_weight
如何作用到损失函数上其次, 看看 sample_weight
是如何作用到 损失函数 上的.
模型的损失函数是通过 compile()
方法设置的, Model.compiled_loss
是模型的 loss
容器, 而 Model.loss
则是模型的原始损失. 训练集的损失最终是在 train_step()
方法里计算的, 计算逻辑是:
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
with backprop.GradientTape() as tape:
y_pred = self(x, training=True)
loss = self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
对应的验证集损失则是在 test_step()
方法里计算的, 其计算逻辑是:
data = data_adapter.expand_1d(data)
x, y, sample_weight = data_adapter.unpack_x_y_sample_weight(data)
y_pred = self(x, training=False)
self.compiled_loss(y, y_pred, sample_weight, regularization_losses=self.losses)
由上的对比可以发现, 训练集和验证集的损失计算方式是一致的.
事实上, sample_weight
并不会直接作用到 compile()
提供的具体的 loss (如 CrossEntropy 等) 上, 而是由 python.keras.losses.Loss.__call__()
对 具体 Loss 返回的样本损失(sample losses) 按照 sample_weight
进行加权, 具体可以参考 Loss.__call__()
源码:
def __call__(self, y_true, y_pred, sample_weight=None):
graph_ctx = tf_utils.graph_context_for_symbolic_tensors(
y_true, y_pred, sample_weight)
with K.name_scope(self._name_scope), graph_ctx:
ag_call = autograph.tf_convert(self.call, ag_ctx.control_status_ctx())
# 获取单个样本的损失
losses = ag_call(y_true, y_pred)
# 按照 sample_weight 进行加权
return losses_utils.compute_weighted_loss(
losses, sample_weight, reduction=self._get_reduction())
class_weight
变化?继续跟进 losses_utils.compute_weighted_loss()
的源码, 可以发现 sample_weight
加权的是这样做的 weighted_losses = math_ops.multiply(losses, sample_weight)
, 对于加权后的损失是 loss = reduce_weighted_loss(weighted_losses, reduction)
进行均值的, 具体代码是:
def reduce_weighted_loss(weighted_losses,
reduction=ReductionV2.SUM_OVER_BATCH_SIZE):
"""Reduces the individual weighted loss measurements."""
if reduction == ReductionV2.NONE:
loss = weighted_losses
else:
loss = math_ops.reduce_sum(weighted_losses)
if reduction == ReductionV2.SUM_OVER_BATCH_SIZE:
loss = _safe_mean(loss, _num_elements(weighted_losses))
return loss
其中分母 _num_elements(weighted_losses)
实质上是 tf.size(weighted_losses)
, 即 weighted_losses
的元素个数. 显然, 当 sample_weight
变化时, loss
的绝对值会跟着变化. 感觉这里最好使用 math_ops.reduce_sum(sample_weight)
来规范化 loss
比较好.
但实际上, 造成训练集 和 验证集损失不同的原因却不在这里, 通过跟进 Model.fit
代码可以发现, class_weight
压根没有作用于验证集, 代码如下:
# for training dataset
data_handler = data_adapter.DataHandler(
x=x,
y=y,
sample_weight=sample_weight,
batch_size=batch_size,
steps_per_epoch=steps_per_epoch,
initial_epoch=initial_epoch,
epochs=epochs,
shuffle=shuffle,
class_weight=class_weight,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
model=self,
steps_per_execution=self._steps_per_execution)
# for validation dataset
self._eval_data_handler = data_adapter.DataHandler(
x=val_x,
y=val_y,
sample_weight=val_sample_weight,
batch_size=validation_batch_size or batch_size,
steps_per_epoch=validation_steps,
initial_epoch=0,
epochs=1,
max_queue_size=max_queue_size,
workers=workers,
use_multiprocessing=use_multiprocessing,
model=self,
steps_per_execution=self._steps_per_execution)
因此 class_weight
对验证集是无影响的, 自然也就对验证集的损失没有影响. 既然 sample_weight
可以作用于验证集, 为什么 class_weight
不可以呢?
那么 class_weight
应该作用于验证集吗? 至少从保持 训练集 和 验证集 的损失的一致性这个点出发, class_weight
是应该作用于验证集的.
相关讨论已发起 tensorflow/tensorflow:issue#42647, 看看大佬们怎么看待这个问题.