如何处理不平衡数据集的分类任务

在情感分类任务中,数据集的标签分布往往是极度不平衡的。以我目前手上的这个二分类任务来说,正例样本14.4万个:负例样本166.1万 = 1 :11.5。很显然这是一个极度不平衡的数据集,假设我把样本全部预测为负,那准确率也高达92%,但这么做没有意义。

那么我们如何处理这个不平衡数据集呢?

因为我用的是神经网络,我不希望减少训练样本,因此我不会采用下采样的方式。有三个方向可以尝试:

  • 使用自定义的loss函数
  • 设置class weight
  • 设置sample weight

这里,我将尝试三种不同的loss函数并进行对比。

(一)三种损失函数

令表示样本的真实标签,则。令表示sigmoid输出的预测类别为1的概率,显然。下面我们给出三种损失函数的定义。

1. Binary crossentropy

2. 修正的Binary crossentropy

这个损失函数来自苏神的两篇文章:
【1】文本情感分类(四):更好的损失函数
【2】何恺明大神的「Focal Loss」,如何更好地理解?

引入单位跃阶函数

取定阈值(为可调超参数,原则上大于0.5均可),则:

这里我稍微修改了一点点,以使得损失函数更加对称。

这个损失函数跟Binary Crossentropy比起来,就是多了这个调节因子,我们来分析一下这个公式:

  • 当正样本的预测概率大于m时,根据的定义,这一项的损失就变为了0;当正样本的预测概率小于m时,保持这一项损失不变;
  • 当负样本的预测概率小于1-m时,这一项的损失也变为了0;当负样本的预测概率大于1-m时,保持这一项损失不变。

也就是说,这个损失函数将焦点放在了分类错误的样本上面,希望能够把更多的样本正确分类。

3. Focal Loss

来自论文Focal Loss for Dense Object Detection


其中为权重因子,为调节参数。

我们来分析一下这个函数,考虑的情形:

  • 当正样本的预测概率接近1时(我们希望的),则接近0,则就变得很小很小。也就是说,当某个样本分类合理时,函数会对其损失进行打折(down weighting),则打折的幅度依赖于参数。
  • 当正样本的预测概率接近0时(我们不希望的),则接近1,因此会变小一点点,但跟上面的情况比起来,其实相当于是放大了,因为大小是相对的。

对于负样本,同理可分析,此处略过。

再来考虑这个参数,它其实是一个权重的调节因子,用于平衡正负样本的损失贡献。但由于的存在,我们很难从实际数据中得到指导来设置这个参数,更多可能要去尝试和调参。一般情况下,先令。

Focal Loss函数对容易分类的样本进行down weighting,聚焦于难分类的样本上。跟苏神的那个思路类似,却更加高明。

那么在类别不均衡的分类任务中,这个损失函数到底怎么起作用呢?

我们知道, 而我的任务中负样本占了绝大多数,对模型来说,它们绝大部分是很好分类的样本,因此它们的损失贡献会大打折扣,从而使模型聚焦在难分类的样本上面,包括绝大部分的正样本。同时,这个参数也能起到平衡正负样本损失的作用。

(二) 损失函数代码

1. 修正的Binary crossentropy(keras版本)

import keras.backend as K

margin = 0.8
theta = lambda t: (K.sign(t)+1.)/2.

def variant_crossentropy_loss(y_true, y_pred):
    return  - theta(margin - y_pred) * y_true * K.log(y_pred + 1e-9)
            - theta(y_pred - 1 + m) * (1 - y_true) * K.log(1 - y_pred + 1e-9))

2. Focal Loss(tensorflow版本)

由于网上的代码都是多分类的(基于softmax输出的),这里我写了一个二分类的(基于sigmoid输出),同时我还加了一个rescale的flag来控制损失函数的量级,单任务学习中,这个flag按照默认的False即可。

import tensorflow as tf

def variant_focal_loss(gamma=2., alpha=0.5, rescale = False):

    gamma = float(gamma)
    alpha = float(alpha)

    def focal_loss_fixed(y_true, y_pred):
        """
        Focal loss for bianry-classification
        FL(p_t)=-rescaled_factor*alpha_t*(1-p_t)^{gamma}log(p_t)
        
        Notice: 
        y_pred is probability after sigmoid

        Arguments:
            y_true {tensor} -- groud truth label, shape of [batch_size, 1]
            y_pred {tensor} -- predicted label, shape of [batch_size, 1]

        Keyword Arguments:
            gamma {float} -- (default: {2.0})  
            alpha {float} -- (default: {0.5})

        Returns:
            [tensor] -- loss.
        """
        epsilon = 1.e-9  
        y_true = tf.convert_to_tensor(y_true, tf.float32)
        y_pred = tf.convert_to_tensor(y_pred, tf.float32)
        model_out = tf.clip_by_value(y_pred, epsilon, 1.-epsilon)  # to advoid numeric underflow
        
        # compute cross entropy ce = ce_0 + ce_1 = - (1-y)*log(1-y_hat) - y*log(y_hat)
        ce_0 = tf.multiply(tf.subtract(1., y_true), -tf.log(tf.subtract(1., model_out)))
        ce_1 = tf.multiply(y_true, -tf.log(model_out))

        # compute focal loss fl = fl_0 + fl_1
        # obviously fl < ce because of the down-weighting, we can fix it by rescaling
        # fl_0 = -(1-y_true)*(1-alpha)*((y_hat)^gamma)*log(1-y_hat) = (1-alpha)*((y_hat)^gamma)*ce_0
        fl_0 = tf.multiply(tf.pow(model_out, gamma), ce_0)
        fl_0 = tf.multiply(1.-alpha, fl_0)
        # fl_1= -y_true*alpha*((1-y_hat)^gamma)*log(y_hat) = alpha*((1-y_hat)^gamma*ce_1
        fl_1 = tf.multiply(tf.pow(tf.subtract(1., model_out), gamma), ce_1)
        fl_1 = tf.multiply(alpha, fl_1)
        fl = tf.add(fl_0, fl_1)
        f1_avg = tf.reduce_mean(fl)
        
        if rescale:
            # rescale f1 to keep the quantity as ce
            ce = tf.add(ce_0, ce_1)
            ce_avg = tf.reduce_mean(ce)
            rescaled_factor = tf.divide(ce_avg, f1_avg + epsilon)
            f1_avg = tf.multiply(rescaled_factor, f1_avg)
        
        return f1_avg
    
    return focal_loss_fixed

(三)结果对比

我采用了双向GRU模型,在保持模型以及数据不变的情况下,仅改变损失函数以对比不同损失函数在我任务上的表现,由于Local CV是我关注的最终metric,我使用Local CV作为early stopping的依据,若两个epoch后Local CV没有得到提升,则模型停止训练,并取Local CV最高的模型作为预测模型。

以下是各个损失函数在验证集上表现,我取了三个维度:

  • Accuracy(阈值为0.5)
  • AUC(Area Under the Curve)
  • Local CV(本质是多个AUC的加权平均)
如何处理不平衡数据集的分类任务_第1张图片

根据以上数据,我们可以得到如下结论:

  1. Focal Loss在我的任务上获得了最大的Local CV,比带class weight的Binary Crossentropy损失高出3个千分点。Focal Loss真是一个优秀的损失函数!
  2. Variant Crossentropy这个损失函数获得了最高的Accuracy,但是AUC和Local CV都很低,显然不适合我手中的任务。根据Variant Crossentropy的公式,其实也可推断,这个损失函数是在优化Accuracy。对于关注正确率的任务,这个损失函数应该是不错的选择。

参考资料:

【1】 非平衡数据集 focal loss 多类分类
【2】Focal Loss for Dense Object Detection
【3】文本情感分类(四):更好的损失函数
【4】何恺明大神的「Focal Loss」,如何更好地理解?

你可能感兴趣的:(如何处理不平衡数据集的分类任务)