【论文笔记】The Impact of Imbalanced Training Data for CNN


原文是:《The Impact of Imbalanced Training Data for Convolutional Neural Networks》

本博客是该论文的阅读笔记,不免有很多细节不对之处。

还望各位看官能够见谅,欢迎批评指正。

更多相关博客请猛戳:http://blog.csdn.net/cyh_24

如需转载,请附上本文链接:http://blog.csdn.net/cyh_24/article/details/49871387

Abstract

本文主要研究使用不平衡数据训练CNN对图像分类的影响。文中使用的数据集是CIFAR-10,作者使用这个数据库,人工地对不同类别生成不同数量分布的数据。比如,让一个类别的图像占很大的比例,而另一类占很小的比例。使用这些生成的不同的训练集,均去训练一个CNN,并测试得到相应的准确率。

结果显示,不平衡训练集会对结果造成很大的负面影响,而训练集在平衡的情况下,能够达到最好的performance。

并且,文中得出一个结论:oversampling是一个很好的效的方式来解决不平衡训练集的问题。

实验过程

Dataset

使用的数据集是CIFAR-10,该数据集有10个类,每类6000张,共6w张图像。

对CIFAR-10进行数据切分,使用其中的5000张作为训练,1000作为测试图像。

生成不同数据分布

解释一下上图:

  • Dist.1 是balanced data,每个类都占10%比重;
  • Dist.2表明airplane,automobile,bird和cat各占8%,而其他类别各占12%…这个应该能看懂吧。

所以,现在有了11个训练集,接下来使用相同的CNN来训练,还是使用原来的test data进行测试。

Oversampling

文中使用的oversampling方式非常简单:

对于每一类,随机选出一些图片进行复制,直到该类图片数量与占最大比重的图片相等。

Results

Distribution Performace

Oversampling Performance

以上是经过oversampling之后的训练的CNN的performance,可以看出,几乎每个类都有提升,不过Dist.1(balanced training data)还是最高的。

Total Performance

平均以下每个Dist的准确率,得到如下表所示的准确率比较图,深色是imbalanced 的准确率,浅色是oversampling之后的准确率。

文章目标很明确,思路也很简单,并没有其他trick,我也就讲到这了。

总结一下,文章讲的事情和结论:

  1. 训练数据分布情况对CNN结果产生很大影响;
  2. 显然,balanced训练集是最优的,数据越不平衡,准确率越差;
  3. 使用Oversampling能够提升准确率;

你可能感兴趣的:(【机器学习&深度学习】,游戏编程模式)