Why can machines learn?

我们已经探讨了从训练集中学习得到可用于训练集之外数据的经验这件事可行的原因,要想得到有效的学习,需要以下两个方面:

Why can machines learn?_第1张图片

先抛开第二点不谈,这涉及到具体算法。单说第一点,可选假设数量M是有限的,这个条件是不容易满足的,即使是最简单的感知机模型,我们的假设空间包含的假设也是无限的,这样一来,我们将无法保证模型的泛化能力。

Why can machines learn?_第2张图片

很自然的思路就是寻找一个和假设空间H相关的有限值,来代替这个可能无限的M,来保证这个上界对于我们的分析是有效的。

如何做到呢?首先我们考虑一下现在这个宽松的上界是怎么来的。

Why can machines learn?_第3张图片

没错,就是一个简单的Union Bound……也就是说我们用的是多个集合并集的元素个数不大于各个集合元素个数之和……这个上界不宽松才怪……尤其是对于上图中这种overlap很大的情况。实际上,有很多的假设h确实是有很大的overlap,因为有很多h在训练集上的运行结果根本就是完全一样的,而且它们在训练集外的表现可能也差不多,这样一来,我们就可以把假设空间H分为几类来考虑,情况将大大简化。

Why can machines learn?_第4张图片

我们先考虑最简单的感知机二分类模型,当只有一个点的时候,有多少条线可以把这一点进行分类呢?无限条。但如果我们按照分类结果可以把线划分为几类呢?两类!

这就是我们的思路:按照分类结果将假设分类。

现在我们已经把无限的h转化成有限的类,接下来的问题就是,这样分类后的假设种类数够不够少呢?如果类别数依然很大,那么我们依然得不到好的上界,或者说依照这个上界我们需要很大的样本数N才能得到满意的泛化性能。可以优化吗?我们先观察两个样本点和三个样本点的情况:

Why can machines learn?_第5张图片

Why can machines learn?_第6张图片

Why can machines learn?_第7张图片

我们可以看到,对于N个样本点,最多有2 N种分类结果,所以我们假设种类的上限就是2 N,那么我们真的可以达到每一种分类结果吗?不一定,比如上面三点共线的情况,我们就有两种分类结果无法通过一条直线划分出来。当然,如果三点不共线,我们还是可以得到所有可能的结果的。这样看来,似乎取上界的时候我们就只能取到2 N了,随N指数级增长。真的如此吗?我们来看四个点的情况:

Why can machines learn?_第8张图片

我们发现,无论我们怎么摆放4个点,总有两种分类结果是无法得到的(实际上这和感知机无法处理异或的情形一样),这让我们看到了一线曙光,我们似乎可以打破2N这个指数级的边界,转而得到一个关于N的多项式。

Why can machines learn?_第9张图片

我们定义一个生长函数,mH(N)表示对于所有可能的N个样本点,假设空间H中的假设最多可以产生的分类情况数。这个生长函数如何计算呢?或者说我们可以找到它的一个比较紧的上界吗?

Why can machines learn?_第10张图片

之前我们已经看到,对于二维的感知机模型,当样本数量为4的时候,我们就无法产生所有可能的分类了,容易知道,当样本数量大于4时,我们同样也不可能产生所有分类。定义使得样本不能产生所有分类的样本数为break point。接下来我们看一下break point的存在是如何限制了成长函数随N的增长速度。

Why can machines learn?_第11张图片

从上图我们看到,当有3个数据点,且最小的break point为2时,我们在这3个点上最多只能产生4种分类结果,这是因为我们必须保证其中任意两点都不能产生所有的4种分类(因为break point为2)。

Why can machines learn?_第12张图片

我们看到,break point确实限制了生长函数的增长,至于这个限制有多强,我们希望能把它限制到N的多项式的水平。接下来我们证明,确实如此。

Why can machines learn?_第13张图片

首先我们定义一个Bounding Function,表示给定最小的break point k情况下生长函数的最大可能值。

Why can machines learn?_第14张图片

我们的思路是建立递推式,其实这是十分自然的想法,因为我们从之前的例子可以看出,3个点的情形之所以被限制,是因为任意两点不能存在所有的4种分类方式,这实际上就隐含了递推。我们现在看上面这个k=3的情形,此时3个点最多有7种分类情况,假设这7种情况如图,那么我们可以找到对应的4个点的最多分类情况。从图中我们可以看到,左边的橙色部分实际上代表任意两个点都不被shatter的情况,这些情况都可以加一个O或者加一个X扩充成右表的橙色部分,因为左边任意两个点都不被shatter,所以扩充后任意3个点都不被shatter。这样我们就建立起了递推式:

Why can machines learn?_第15张图片

更一般的,我们有:

Why can machines learn?_第16张图片

这种形式的递推式和组合数的递推式完全相同,所以容易证明:

Why can machines learn?_第17张图片

至此我们已经证明了,生长函数关于N是多项式的!

Why can machines learn?_第18张图片

我们想要的是直接用生成函数去代替M,但实际上的公式应该是下面那个,多了一些系数。具体细节就不展开了,总之这个界已经可以处理M无限的情况了,实际上,这个上界已经和M没关系了,这是非常好的结果。这个界正式的名字是:

Why can machines learn?_第19张图片

至此,我们就了解了学习可行的理论基础。至于在训练集上找到一个好的假设h,这就是具体的学习算法要做的事情了。算法这部分毫无疑问是多数机器学习课程的重点,而提到泛化相关理论证明的并不多,所以在这里记录一下,心中有数就好,这样各种算法用起来我心里也踏实一些……

再补充一点,就是过拟合问题,过拟合是现在机器学习面临的核心问题,西瓜书里也给出了过拟合不可能完全避免的粗略论证。从我们的VC上界来考虑过拟合,实际上就是当模型过于复杂的时候,假设空间变大,对于点的各种分类情形就越容易达到,就比如二维感知机无能为力的异或情形,只要加一层隐藏层就搞定了。这样一来我们最小的break point就变大了,从而直接影响了这个上界的最高次数,使得泛化性能下降。

你可能感兴趣的:(Why can machines learn?)