筛选样本_困难样本(Hard Sample)处理方法

筛选样本_困难样本(Hard Sample)处理方法_第1张图片

困难样本(Hard Sample)处理方法

如果按照学习的难以来区分,我们的训练集可以分为Hard Sample和Easy Sample. 顾名思义,Hard Sample指的就是难学的样本(loss大),Easy Sample就是好学的样本(loss小)。

举个金融领域的情感分类的例子:

"隔夜要闻:美国12月CPI不及预期 原油终结五连跌"-----Hard Sample
​
"商品期货多数上涨 甲醇上涨5%封涨停"-----Easy Sample

这个概念究竟啥时候提出来的,我也没查到...但是其思想其实在传统机器学习中早有体现,还记得SVM中使用的Hinge Loss,Margin为1的样本其实就是Hard Sample,Margin大于1的就是Easy Sample,只不过这里直接粗暴的对Margin大于1的样本抛弃掉。还有一个更典型的算法——Adaboost,它根据每轮学习的以后的loss大小来动态的改变样本的权重,这其实就是Hard Sample Mining!

如何处理Hard Sample,主要有几种方式:

  • 如果有更多的数据或者获得数据成本较低,请优先增加你的数据量
  • 如有有廉价的方式做Hard Sample数据增强,请优先做
  • 使用能处理Hard Sample的Loss
  • OHEM: Online Hard example mining

Q1: 为什么优先增加你的数据?

一般来说,Hard Sample对应着不平衡样本。试想,如果你的“Hard Sample”占比很大,Diversity很丰富,那么模型还学不好这些样本,就别找借口了,老老实实debug.....原理上讲就是,如果Hard Sample很多,那么它们对于loss和gradient贡献都比较大,模型应该比较容易学好这些样本。

即便样本不平衡下,Easy Sample和Hard Sample比例保持一致的情况下,10个Hard样本和1000个Hard样本是有质的区别的,因为Easy Sample的Diversity增加100倍没啥区别,Hard Sample增加100倍Diversity就有很好的收益。见得多了自然好学一些嘛...那时候说不定,他们已经不再是Hard Sample了!所以,搞定下班咯~

Hard Sample数据增强

在你数据已经一定的情况下,如果你有廉价资源做Hard Sample数据增强,DO IT!典型的场景是CV,你可以使用各种手段(翻转、镜像、旋转、裁剪等等)做数据增强,也可以使用一些生成模型。早在17年就有人用GAN做Hard Sample数据增强了,具体看这篇paper: A-Fast-RCNN: Hard Positive Generation via Adversary for Object Detection,就凭这几年各种GAN的发展,该方法已经非常成熟,略等于免费的蛋糕。

Hard Sample数据增强有两点好处,第一和上面的扩大数据量一致,使得Hard Sample绝对数量增加;第二是可以缩小Easy Sample和Hard Sample比例,使得每个batch,Hard Sample对梯度影响更大,训练效果偏向于Hard Sample。

但是该方法也有弊端:

  • 有一些领域没有廉价的方法做Hard Sample增强,比如典型的NLP。(灵魂吐槽#NLP太难了#)
  • Hard Sample增强的数据容易和原始数据集分布不一致。这条看场景以及模型效果

Loss 改进

其实很多人知道Hard Sample这个词就是从Kaiming大神的Focal Loss吧。

先抛开Focal Loss不谈,看了前面的内容,你应该知道所有应对Hard Sample的Loss改进核心出发点都是一个:增大Loss大的样本对Gradient的贡献

有了这个认识,Focal Loss原理都很简单,它源于目标检测

Focal Loss:

交叉熵损失函数如下

其中ti是第i个样本xi的target,pi是模型预测xi属于类1的概率。

Focal Loss 形式如下:

忽略掉了Focal中加入的class weight,它是用来解决样本不平衡的,在这里不重要,

​是一个超参数。它对CELoss加了一个优化,能够使得Hard Sample对loss贡献更大。看公式,如果pi比较大,即一个Easy Sample,则​就比较小,再​
次方一下,会更小。所以Focal Loss在损失函数上就可以使得Hard Sample在loss中贡献更大,从而使得训练效果对Hard Sample学的更好。

一开始也说了,Hard Sample经常伴随样本不平衡问题,那么其他的loss改进,比如weighted loss, dice loss, Lovasz loss都是应对不平衡问题的。

另一个还需要注意到的是,有时候所谓的Hard Sample是一些标注错误的样本(异常outliers),比如把一个蛋糕标注成了狗。那这种情况下,如果强行当做Hard Sample处理反倒会影响模型的泛化效果。这时候推荐尝试GHM,这里也顺带简单介绍一下GHM吧。

Gradient Harmonizing Mechanism: GHM

先看一张图

筛选样本_困难样本(Hard Sample)处理方法_第2张图片

这张图是一个已经收敛的模型,在测试数据集中的梯度分布。最左边梯度接近0的就是Easy Sample,他们已经学的非常好了;中间的一部分区域对应着不同难度的样本,他们对应着Hard Sample;最后一部分对应着Outliers。其实我们只是希望Hard Sample贡献更多的梯度,但是如果使用Focal Loss,不出意外Outliers会贡献更多的梯度,白话就是我们非要模型把一张误标为狗的蛋糕识别为蛋糕,这肯定是不好的。

GHM就是解决这个问题的,它仿照物理中的密度定义一个梯度密度的概念,把上面的图划分为几个区间,比如梯度范围[0-0.1)有1000个,那么对应区间的梯度就是10000. 然后使用梯度密度的倒数对CE Loss做加权。看上图,最左边的Easy Sample和最右边的Outliers密度较高,所以他们对loss贡献减小,中间的Hard Sample密度小,loss贡献变大。核心思想就这样,具体公式就不说啦。

有一点很重要的是,如果一开始就使用GHM,会使得一开始的收敛速度很慢。原因其实都在原理上。因为一开始模型很垃圾,几乎所有的样本loss都很大,都集中在上图的最右边,所以就导致所有样本的loss权重都很低,前几个epoch收敛慢也就理所当然了。所以正确的做法是,先使用CE Loss训练,几个epoch以后再开始使用GHM

OHEM

最后一种方式就是Online Hard example mining。其实上面说的Focal Loss也算是OHEM一种方式,只不过由于现在OHEM特指一种技巧,所以就一般不这么分。

OHEM核心思想是,train的过程中,每过几个epoch,我计算一下每个样本的loss,然后加入更多Hard Sample。实际中除了训练数据外还有一个Data Pool,放额外的数据,它是用来筛选Hard Sample的池子。当然你也可以采用一种类似与重采样的做法,Training Data和Data Pool一样,每次从Data Pool选择Hard Sample扔进Training Data.

一般两种模式:

  • Replace模式(更常用),筛选出的Hard Sample替换训练数据中的Easy Sample,可以保持训练数据大小不变。
  • Add模式,筛选出的Hard Sample直接加入训练数据。

所以其实OHEM和Focal Loss本质都一样,只不过一个是”硬“一个是”软“的,是不是和Hard Margin SVM和Soft Margin SVM异曲同工呢?是不是还和One-hot Label和Label Smothing异曲同工呢?是不是还和Hard Attention和Soft Attention异曲同工呢?其实”软"和"硬”在机器学习中是非常常见的两种模式。

最后要说明,没事别上OHEM!因为它实在是太慢了!!原因是我们做inference时候,loss都是一个batch算的,不管你是sum还是avg,都没办法区分一个batch中的样本了。所以ohem的做法就很粗暴了,使用batch size=1做Hard Sample的筛选,如果你的样本池很大,那是非常非常慢的,谁用谁知道....

照应一下开头,本文讨论Hard Sample优化问题,方法如下:

  • 如果有更多的数据或者获得数据成本较低,请优先增加你的数据量
  • 如有有廉价的方式做Hard Sample数据增强,请优先做
  • 使用能处理Hard Sample的Loss
  • OHEM: Online Hard example mining

最后祝大家新年快乐~~~~

你可能感兴趣的:(筛选样本)