【导读】对于机器学习而言,获取数据的成本有时会非常昂贵,因此为模型选择一个合理的训练数据规模,对于机器学习是至关重要的。在本文中,作者针对线性回归模型和深度学习模型,分别介绍了确定训练数据集规模的方法。
数据是否会成为新时代的“原油”是人们近来常常争论的一个问题。
无论争论结果如何,可以确定的是,在机器学前期,数据获取成本可能十分高昂(人力工时、授权费、设备运行成本等)。因此,对于机器学习的一个非常关键的问题是,确定能使模型达到某个特定目标(如分类器精度)所需要的训练数据规模。
在本文中,我们将对经验性结果和研究文献中关于训练数据规模的讨论进行简明扼要的综述,涉及的机器学习模型包括回归分析等基本模型,以及复杂模型如深度学习。训练数据规模在文献中也称样本复杂度,本文将对如下内容进行介绍:
- 针对线性回归和计算机视觉任务,给出基于经验确定训练数据规模的限制;
- 讨论如何确定样本大小,以获得更好的假设检验结果。虽然这是一个统计问题,但是该问题和确定机器学习训练数据集规模的问题很相似,因此在这里一并讨论;
- 对影响训练数据集规模的因素,给出基于统计理论学习的结果;
- 探讨训练集增大对模型表现提升的影响,并着重分析深度学习中的情形;
- 给出一种在分类任务中确定训练数据集大小的方法;
- 探讨增大训练集是否是应对不平衡数据集的最好方式。
基于经验确定训练集规模的限制
首先,我们依据使用的模型类型,探讨一些广泛使用的经验性方法:
通过假设检验确定样本规模
假设检验是数据科学常用的一种统计工具,一般也可以用于确定样本规模。
举个例子:某科技巨头搬去 A 城后,A 城的房价便急剧上涨,而某记者想知道现在每套公寓的均价是多少。那么问题来了,在保证 95% 的置信度,60 K 的公寓价格标准差,且价格误差在10K 以内的条件下,计算多少栋公寓的均价较为合理?
相应公式见下图,其中 N 为所需的样本规模,1.96 为标准正态分布在 95% 置信度下所对应的常数:
样本量估计
根据上述公式,该记者需要考虑大概 138 栋公寓的价格。
该公式将随着检验问题的不同而改变,但是都要通过置信区间、可容忍误差和标准差值来计算。
训练数据规模的统计学习理论
我们先介绍一下著名的 VC 维(Vapnik-Chevronenkis dimension)。VC 维是一种模型复杂度的度量;模型越复杂,它的 VC 维就越高。下面介绍根据 VC 维来确定训练数据规模的公式。
首先,通过一个例子来看一下 VC 维是如何计算的:假设一个二维平面上有三个点需要被分类,而我们的分类器为该平面上的一条直线。无论这三点怎样组合(均为正例,两正一负、一正两负等),这条直线都能正确地将正负样本归类/分开。那么,我们就认为一个线性分类器可以划分这三点中的任意一点,因而它的 VC 维至少为 3。
另外,由于存在四个点的组合不能被一条直线准确分开,所以这个线性分类器的 VC 维为 3。可以证明,训练数据规模 N 是 VC 维的一个函数:
由 VC 维估计训练数据规模
其中 d 为失败率, epsilon 为学习中的误差率。由此可见,学习模型所需的样本量取决于模型的复杂度。但该方法有一个弊端,就是在面对神经网络显著的复杂度时,会要求十分庞大的训练数据规模。
当训练集增大时,模型的表现会持续提升吗?在深度学习任务又如何呢?
上图展示了随着数据规模的增长,传统的机器学习算法(回归等)和深度学习表现的变化。
具体来看,对于传统的机器学习算法,模型的表现先是遵循幂定律(power law),之后趋于平缓;而对于深度学习,该问题还在持续不断地研究中,不过图一为目前较为一致的结论,即随着数据规模的增长,深度学习模型的表现会按照幂定律持续提升。例如,有人曾用深度学习方法对三亿张图像进行分类,发现模型的表现随着训练数据规模的增长按对数关系提升。
值得注意的是,在深度学习中也有一些与上述例子不同的结果。比如,在《Learning Visual Features from Large Weakly Supervised Data》一文中,作者使用了一亿条 Flickr 上的图片和标签来训练卷积神经网络,刚开始模型表现会随着数据规模的增大而提升,但超过五千万张图片后模型的效果提升就不太明显了。
文章《How Training Data Affect the Accuracy and Robustness of Neural Networks for Image Classification》的作者还发现,随着训练数据规模的增加,图像分类的准确度确实会上升;但是,模型的鲁棒性会在数据规模到达一定程度后开始下降。
分类任务中确定训练数据集大小的方法
该方法基于我们所熟知的学习曲线,一般而言,学习曲线图的纵轴为误差,横轴为训练数据集大小。《 Tutorial: Learning Curves for Machine Learning in Python》和《Learning Curve》是很好的参考资料,可以用于进一步了解机器学习中的学习曲线,以及它们是如何随着偏差或方差的增加而变化的。Python 在 scikit-learn 中提供了一种学习曲线函数。
在分类任务中,我们往往会使用学习曲线的一种轻微变体,在该曲线图中,纵轴为分类准确度,横轴为训练数据集大小。训练集规模的确定十分简单:只需针对你的问题,先确定学习曲线的确切形状,然后找到曲线上你预期的分类准确度所对应的训练数据集大小即可。
例如,在文章《Predicting Sample Size Required for Classification Performance》和《How Much Data Is Needed to Train A Medical Image Deep Learning System to Achieve Necessary High Accuracy?》中,作者们将学习曲线的方法应用到了医学领域,并且给出了一个相应的幂函数:
学习曲线公式
其中,y 为分类准确度,x 为训练集,b1,b2 分别为学习率和衰减率。根据问题的不同,参数会有所不同,可以通过非线性回归或加权非线性回归对参数进行估计。
增大训练集是应对不平衡数据集的最好方式?
文章《Precision-Recall Versus Accuracy and the Role of Large Data Sets》对该问题进行了讨论。该文作者提出了一个很有意思的观点:在不平衡的数据集下,准确度并不是一个分类器表现好坏的最佳度量。
原因很简单,对于一个负样本为主的数据集,模型往往通过将大部分样本分类为负样本,以提高准确度。为了更好地衡量模型效果,他们将准确率和召回率(又称敏感性)作为不平衡数据集下度量模型表现的合理标准。
除了上述提到的关于准确度的问题,作者们还指出,对于存在不平衡数据的问题而言,模型的准确率往往对其更加重要。比如一个医院的警报系统而言,高精确率就意味着当警铃响起时,很有可能确实有病人遇到了麻烦。
之后,该文章分别使用较大的非平衡训练集和不平衡学习包(imbalanced-learn, 基于Python scikit-learn)对模型进行了训练,并使用准确率和召回率对训练效果进行了分别的度量。
第一个模型使用了一个包含5万个样本的药物研发数据,并构建了使用不平衡矫正方法的K-近邻模型。第二个模型使用了一个包含大约100万个样本的数据集上,构建了一个简单的K-近邻模型。
其中,不平衡矫正方法包括欠采样、过采样和集成学习。文章作者重复了200次实验,其结论为,当把精确率和召回率作为度量时,没有任何一种不平衡矫正方法比增加更多训练数据的效果更好。
原文链接:
https://towardsdatascience.com/how-do-you-know-you-have-enough-training-data-ad9b1fd679ee
References
[1] The World’s Most Valuable Resource Is No Longer Oil, But Data,https://www.economist.com/leaders/2017/05/06/the-worlds-most-valuable-resource-is-no-longer-oil-but-data May 2017.
[2] Martinez, A. G., No, Data Is Not the New Oil,https://www.wired.com/story/no-data-is-not-the-new-oil/ February 2019.
[3] Haldan, M., How Much Training Data Do You Need?, https://medium.com/@malay.haldar/how-much-training-data-do-you-need-da8ec091e956
[4] Wikipedia, One in Ten Rule, https://en.wikipedia.org/wiki/One_in_ten_rule
[5] Van Smeden, M. et al., Sample Size For Binary Logistic Prediction Models: Beyond Events Per Variable Criteria, Statistical Methods in Medical Research, 2018.
[6] Pete Warden’s Blog, How Many Images Do You Need to Train A Neural Network?, https://petewarden.com/2017/12/14/how-many-images-do-you-need-to-train-a-neural-network/
[7] Sullivan, L., Power and Sample Size Distribution,http://sphweb.bumc.bu.edu/otlt/MPH-Modules/BS/BS704_Power/BS704_Power_print.html
[8] Wikipedia, Vapnik-Chevronenkis Dimension, https://en.wikipedia.org/wiki/Vapnik%E2%80%93Chervonenkis_dimension
[9] Juba, B. and H. S. Le, Precision-Recall Versus Accuracy and the Role of Large Data Sets, Association for the Advancement of Artificial Intelligence, 2018.
[10] Zhu, X. et al., Do we Need More Training Data?https://arxiv.org/abs/1503.01508, March 2015.
[11] Shchutskaya, V., Latest Trends on Computer Vision Market,https://indatalabs.com/blog/data-science/trends-computer-vision-software-market?cli_action=1555888112.716
[12] De Berker, A., Predicting the Performance of Deep Learning Models,https://medium.com/@archydeberker/predicting-the-performance-of-deep-learning-models-9cb50cf0b62a
[13] Sun, C. et al., Revisiting Unreasonable Effectiveness of Data in Deep Learning Era, https://arxiv.org/abs/1707.02968, Aug. 2017.
[14] Hestness, J., Deep Learning Scaling is Predictable, Empirically,https://arxiv.org/pdf/1712.00409.pdf
[15] Joulin, A., Learning Visual Features from Large Weakly Supervised Data, https://arxiv.org/abs/1511.02251, November 2015.
[16] Lei, S. et al., How Training Data Affect the Accuracy and Robustness of Neural Networks for Image Classification, ICLR Conference, 2019.
[17] Tutorial: Learning Curves for Machine Learning in Python, https://www.dataquest.io/blog/learning-curves-machine-learning/
[18] Ng, R., Learning Curve, https://www.ritchieng.com/machinelearning-learning-curve/
[19]Figueroa, R. L., et al., Predicting Sample Size Required for Classification Performance, BMC medical informatics and decision making, 12(1):8, 2012.
[20] Cho, J. et al., How Much Data Is Needed to Train A Medical Image Deep Learning System to Achieve Necessary High Accuracy?, https://arxiv.org/abs/1511.06348, January 2016.
[21] Lemaitre, G., F. Nogueira, and C. K. Aridas, Imbalanced-learn: A Python Toolbox to Tackle the Curse of Imbalanced Datasets in Machine Learning, https://arxiv.org/abs/1609.06570