论文阅读(46)OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING

1.论文相关

论文阅读(46)OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING_第1张图片
image.png

2. 摘要

尽管深度神经网络在大数据域中取得了巨大的成功,但它们通常在小样本学习任务中表现不佳,在这些任务中,分类器在看到每个类中很少的示例后必须快速进行归纳。一般认为,基于梯度的高容量分类器优化需要在多个例子上进行多次迭代才能取得良好的效果。在这里,我们提出了一个基于LSTM的元学习模型(LSTM based meta-learner model)来学习精确的优化算法,用于训练另一个学习器神经网络分类器。我们的模型的参数化允许它学习适当的参数更新,特别是对于将要进行一定数量更新的场景,同时还学习学习器(分类器)网络的一般初始化,允许快速收敛训练。我们证明了这种元学习模型与深度度量学习技术在小样本学习中具有竞争力。

2.2 任务描述(TASK DESCRIPTION)

我们首先详细介绍我们使用的元学习公式。在典型的机器学习环境中,我们对一个数据集感兴趣,通常会进行分割,以便在训练集上优化参数,并在测试集上评估其泛化性。然而,在元学习中,我们处理的是包含多个正则数据集的元集,其中每个都分割成一个和。

我们考虑k-shot, N-class分类任务,其中对于每个数据集,训练集由每个类的个标记示例组成,这意味着训练集由个示例组成,测试集具有一组用于评估的示例。我们注意到,之前的研究(Vinyals等人,2016年)使用了episode术语来描述由训练和测试集组成的每个数据集。

因此,在元学习中,我们有不同的元集(meta-sets)用于元训练、元验证和元测试(分别为、和)。在中,我们感兴趣的是训练一个学习过程(元学习器),该学习过程可以将其训练集之一作为输入,并生成一个分类器(学习者),该分类器在其相应的测试集上获得较高的平均分类性能。利用可以对元学习器进行超参数选择,并对其在上的泛化性能进行评价。

为了使该公式对应于小样本学习设置,数据集中的每个训练集将包含少量的标记示例(我们考虑k=1或k=5),这些示例必须用于在相应的测试集上推广到良好的性能。这个公式的一个例子如图1所示。

3 模型

我们现在开始描述我们提出的元学习模型。

3.1 模型说明(MODEL DESCRIPTION)

考虑一个单独的数据集,或episode,。假设我们有一个学习器神经网络分类器,它有我们想要在上训练的参数。用于训练深层神经网络的标准优化算法是梯度下降的一种变体,它使用的更新形式如下:

image.png
论文阅读(46)OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING_第2张图片
image.png

其中,是学习器步更新后的参数,是t时的学习率,是学习器为第时刻更新优化的损失,是损失相对于参数的梯度,是学习器的更新参数。

我们在此利用的关键观察结果是,此更新类似于LSTM (Hochreiter & Schmidhuber, 1997)中的单元状态更新。

image.png

因此,我们建议训练一个元学习器(meta-learner) LSTM来学习更新规则以训练神经网络。我们将 LSTM的细胞状态(cell state)设置为学习器的参数,或,并将候选细胞状态(candidate cell state)设置为?,给出梯度信息对优化有多大价值。我们定义了参数和形式,以便元学习器可以通过更新过程来确定最佳值。

让我们从开始,对应于更新的学习速率。我们让:

image.png

这意味着学习率是当前参数值、目前梯度、目前损失和先前学习率的函数。有了这些信息,元学习器应该能够很好地控制学习速度,以便在避免分歧(avoiding divergence)的同时快速地训练学习器。

至于,似乎最佳选择不是常数1。直观地说,如果学习器目前处于一个糟糕的局部最优状态,并且需要一个大的变化来逃避,那么缩小学习者的参数并忘记其先前的部分值是合理的。这将对应于损失很高但梯度接近于零的情况。因此,对于遗忘门(forget gate)的一个建议是将其作为该信息的函数,以及遗忘门的先前值:

image.png

另外,请注意,我们还可以学习LSTM的单元状态的初始值,将其作为元学习器的参数。这对应于分类器的初始权重(元学习器正在训练)。学习这个初始值可以让元学习器确定学习器的最佳初始权重,以便训练从一个有利的起点开始,允许优化快速进行。最后,请注意,尽管元学习器的更新规则与LSTM的单元状态更新匹配,但元学习器也与GRU(cho et al.,2014)隐藏状态更新具有相似性,除了忘记门和输入门没有绑定到和(forget and input gates aren’t tied to sum to one)。

3.2 参数共享与预处理(PARAMETER SHARING & PREPROCESSING)

因为我们希望元学习器为由上万个参数组成的深层神经网络生成更新,为了防止元学习器参数爆炸,我们需要采用某种参数共享。如Andrychowicz等人所述(2016),我们在学习器梯度的坐标上共享参数。这意味着每个坐标都有自己的隐藏值和单元状态值,但所有坐标的LSTM参数都相同。这使得我们可以使用一个紧凑的LSTM模型,并且还具有一个很好的特性,即对每个坐标使用相同的更新规则,但是在优化过程中依赖于每个坐标各自的历史。我们可以很容易地实现参数共享,输入是一批梯度坐标和每个维度的损失输入。

由于梯度和损失的不同坐标可能有很大的不同,我们需要小心地规范化这些值,以便元学习器能够在训练中正确地使用它们。因此,我们还发现了andrychowicz等人的预处理方法适用于梯度的维度和每个时间步的损失时,效果良好:

image.png

这个预处理调整了梯度和损耗的比例,同时也分离了它们的大小和符号的信息(后者主要用于梯度)。我们发现上述公式中的建议值p=10在我们的实验中效果良好。

3.3 训练(TRAINING)

现在的问题是,我们如何训练LSTM元学习器模型,使其在小样本学习任务中有效?如Vinyals等人所观察到的,为了在这项任务中表现出色,关键是训练条件要与测试时相匹配。在元学习的评价过程中,对于每个数据集(episode),,一个好的元学习模型,在给定一系列的学习器梯度和训练集上的损失的情况下,会对分类器提出一系列的更新,使其在测试集上取得良好的性能。

因此,为了匹配测试时间条件,当考虑每个数据集时,我们使用的训练目标是在测试集上生成的分类器的损失。在对训练集中的示例进行迭代时,LSTM元学习器在每一个时间步从学习器(分类器)接收并提出新的参数集。这个过程重复个步骤,然后在测试集上评估分类器及其最终参数,以产生损失,然后用于训练元学习器。训练算法如算法1所述,相应的计算图如图2所示。

论文阅读(46)OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING_第3张图片
image.png
论文阅读(46)OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING_第4张图片
image.png
3.3.1 梯度独立假设(GRADIENT INDEPENDENCE ASSUMPTION)

注意,我们的公式意味着学习器的损失和梯度取决于元学习者的参数。元学习器参数的梯度通常应该考虑到这种依赖性。然而,正如andrychowicz等人所讨论的那样,这使得元学习器梯度的计算复杂化。因此,遵循andrychowicz等人的观点,我们做了一个简化的假设,即这些对梯度的贡献并不重要,可以忽略,这使得我们可以避免使用二阶导数,这是一个相当昂贵的操作。尽管有这样简单的假设,我们仍然能够有效地训练元学习器。

3.3.2 元学习器LSTM的初始化(INITIALIZATION OF META-LEARNER LSTM)

在训练LSTM时,建议使用小的随机权重初始化LSTM,并将遗忘门偏差设置为大值,以便将遗忘门初始化为接近1,从而启用梯度流(enabling gradient flow)(Zaremba,2015)。除了遗忘门偏差设置,我们发现我们需要将输入门偏差初始化为小,以便元学习器LSTM使用的输入门值(以及学习率)开始变小。通过这种组合初始化,元学习器以较小的学习率开始接近正常梯度下降,这有助于训练的初始稳定性。

3.4 批量规范化(BATCH NORMALIZATION)

批处理规范化(ioffe&szegedy,2015)是最近提出的一种通过减少学习器隐藏层内的内部协变量变化(internal covariate shift)来稳定和加速深层神经网络学习的方法。这种减少是通过规范化每一层的预激活(pre-activation),减去平均值,除以标准差来实现的。在训练期间,平均值和标准偏差是使用当前训练批次估计的,而在评估期间,则使用在训练集上计算的两个统计值的运行平均值。在元学习环境中,我们需要注意学习器网络的批量规范化,因为我们不希望在元测试期间收集平均值和标准偏差统计数据,从而允许不同数据集(集)之间的信息泄漏。防止这个问题的一个简单方法是在元测试阶段根本不收集统计数据,而只使用元训练中的平均值。然而,这对性能有很坏的影响,因为我们已经改变了元训练和元测试的条件,导致元学习器学习一种依赖于批量统计的优化方法,而现在它在元测试时还没有这种方法。为了使这两个阶段尽可能地相似,我们发现一个更好的策略是在期间收集每个数据集的统计数据,然后在考虑下一个数据集时删除正在运行的统计数据。因此,在元训练期间,我们对训练集和测试集使用批统计,而在元测试期间,我们对训练集使用批统计(并计算我们的运行平均值),但在测试期间使用运行平均值。这不会导致不同数据集之间的任何信息泄漏,但也允许元学习器在训练和测试之间匹配的条件下接受训练。最后,由于我们进行很少的训练步骤,所以我们计算了运行平均数,以便对后面的值给予更高的偏好。

4 相关工作(RELATED WORK)

虽然这项工作在广泛的文献中一般属于迁移学习,我们在这里的重点是定位它相对于元学习和小样本学习的先前工作。

4.1 元学习(META-LEARNING)

元学习有着悠久的历史,但随着许多人提倡它是未来实现人类智能水平的关键,元学习近年来变得越来越突出(Lake等人,2016)。在两个层次上学习的能力(在每个任务中学习,同时积累关于任务之间相似性和差异的知识)被认为是提高人工智能的关键。先前的研究在元学习环境中使用了多种技术。

schmidhuber(1992;1993)探索了使用网络来学习如何在输入的许多计算步骤上修改自己的权重。权值的更新以参数形式定义,使得预测和权值变化过程是端到端可微的。Bengio等人的工作。(1990;1995)和Bengio(1993)认为神经网络的学习更新规则在生物学上是合理的。此属性是通过允许更新的参数形式在每个隐藏单元上仅具有作为输入的本地信息来确定权重更改来实现的。不同的优化方法,如遗传规划或模拟退火,被用来训练学习规则。

在Santoro等人(2016),训练记忆增强神经网络,学习如何存储和检索用于每个分类任务的记忆。andrychowicz等人的工作。(2016)使用LSTM训练神经网络;然而,他们对学习通用优化算法训练神经网络进行大规模分类感兴趣,而我们对少数镜头学习问题感兴趣。这项工作也建立在Hochreiter等人的基础上。(2001)和Bosc,他们都使用LSTM训练多层感知器学习二进制分类和时间序列预测任务。另一个相关的方法是Bertineto等人的工作。(2016),他训练元学习者将训练示例映射到神经网络的权重,然后用神经网络对该类未来示例进行分类;然而,与我们的方法不同,分类器网络是直接生成的,而不是经过多个训练步骤后进行微调。我们的工作也与Maclaurin等人的工作相似。(2015),他通过梯度步骤链的反向传播,利用动量调整梯度下降的超参数,以优化验证性能。

4.2 小样本学习

小样本学习的最佳方法主要是度量学习方法。深度暹罗网络(Koch,2015)根据某种距离度量,训练卷积网络以嵌入示例,从而使同一类中的项接近,而不同类中的项远离。匹配网络(Vinyals等人,2016年)通过定义一个可微的最近邻损失(涉及卷积网络产生的嵌入的余弦相似性),改进了这一思想,使训练和测试条件相匹配。

5 评价

在这一部分中,我们描述了实验结果,检验了模型的特性,并将我们的方法与不同的方法进行了性能比较。和Vinyals等人一样,我们考虑的是k-shot, N-class分类设置,一个元学习器在许多相关但很小的训练样本集上训练,N个类,每一个类有k个样本。我们首先将数据中所有类的列表分成不相交的集合,并将它们分配给元训练、元验证和元测试的每个元集合。为了生成一个k-shot, N-class任务数据集的每个实例,我们这样做:我们首先从与我们考虑的元集对应的类列表中采样类。然后,我们从这些类中的每个类中抽取个示例。这些例子一起构成了训练集。然后,对其余示例的附加固定量进行采样以产生测试集。我们通常在测试集中每个类有15个示例。在训练元学习器时,我们通过重复采样这些数据集(episodes)来迭代。然而,对于元验证和元测试,我们生成固定数量的这些数据集来评估每个方法。我们产生了足够多的数据集,以确保平均精度的置信区间很小。

对于学习器,我们使用包含4个卷积层的简单CNN,每个卷积层是一个3×3卷积,带32个滤波器,然后是批处理规范化、ReLU非线性,最后是一个2×2 最大池化。然后,网络有一个最终的线性层,后面是所考虑的类数的softmax。损失函数L是学习器分配给正确类的平均负对数概率(average negative log-probability)。对于元学习器,我们使用两层LSTM,其中第一层是普通LSTM,第二层是我们改进的LSTM元学习器。梯度和损失被预处理并输入到第一层LSTM中,并且第二层LSTM也使用规则梯度坐标来实现(1)所示的状态更新规则。在每一个时间步,学习器的损失和梯度是在一个由整个训练集组成的批上计算的,因为我们考虑的训练集只有5或25个例子。我们使用0.001的学习率与ADAM一起训练LSTM,使用0.25的值进行梯度剪裁(gradient clipping)。

5.1 实验结果

Mini-ImageNet数据集是vinyals等人提出的,作为一个基准,提供ImageNet图像的复杂性的挑战,而不需要在整个ImageNet数据集上运行所需的资源和基础结构。因为在Vinyals等人(2016)没有发布,我们通过从ImageNet中随机选择100个类并选择每个类的600个示例来创建自己的Mini-ImageNet数据集版本。我们分别使用64、16和20个类进行训练、验证和测试。我们考虑了5个类的1-shot和5-shot分类。我们在每个测试集中使用每个类15个示例进行评估。我们对比了两个基线和最近的度量学习技术,匹配网络(Vinyals等人,2016),后者在小样本学习中取得了最先进的结果。结果见表1。

论文阅读(46)OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING_第5张图片
image.png

我们使用的第一个基线是最近邻基线(nearest-neighbor baseline,Baseline-nearest-neighbor),在这个基线中,我们首先训练一个网络,以便在原始元训练集中的所有类之间联合进行分类。在元测试时,对于每个数据集,我们使用我们的训练网络将所有项目嵌入到训练集中,然后在嵌入的训练示例之间使用最近邻匹配对每个测试示例进行分类。我们使用的第二个基线(Baseline-finetune)表示元学习器模型的一个粗略版本。在第一个基线中,我们首先训练一个网络,以便在元训练集中的所有类之间联合分类。然后,我们使用元验证集搜索SGD超参数(SGD hyperparameters),其中每个训练集用于在测试集上评估之前微调预训练的网络。我们使用固定数量的更新对这些更新过程中使用的学习率和学习率衰减进行微调和搜索。

对于匹配网络,我们实现了自己的基本版本和完全条件嵌入(fully-conditional embedding,FCE)版本。在基本版本中,卷积网络被训练来学习在训练集和测试集中样本的独立嵌入。在FCE版本中,使用双向LSTM(bidirectional-LSTM)学习训练集的嵌入,使得每个训练示例的嵌入也是所有其他训练示例的函数。此外,还使用了注意力LSTM,使得测试示例嵌入也是训练集所有嵌入的函数。我们不考虑像Vinyals等人提到在元测试期间使用训练集微调网络以提高性能。但请注意,我们的元学习器也可以使用这些数据进行微调。注意保持与Vinyals等人的一致,我们的基线和匹配网络卷积网络有4层,每个层有64个滤波器。我们还为匹配网络中的每个卷积块添加了dropout以防止过度拟合。

对于我们的元学习器,我们为1-shot和5-shot任务训练不同的模型,分别进行12次和5次更新。我们注意到,如果元学习器在元测试期间被显式地训练为执行设置数量的更新,那么每个任务都会获得更好的性能。

我们得到的结果比所讨论的基线要好得多,并且与匹配网络竞争。对于5-shot,我们能够比匹配网络做得更好,而对于1-shot,我们性能的置信区间与匹配网络的置信区间相交。我们再次注意到,这些数字与Vinyals等人提供的数字不匹配。仅仅因为我们创建了我们的数据集版本并实现了我们自己的模型版本。值得注意的是,微调后的基线比最近邻基线差。因为我们没有对分类器进行正则化,很少更新微调模型,特别是在单样本情况下。这种过度拟合的倾向说明了元训练的好处,正如在元学习LSTM中所做的那样,分类器的初始化是端到端的。

5.2 元学习器的可视化(VISUALIZATION OF META-LEARNER)

我们还将元学习器学习的优化策略可视化,如图3所示。我们可以在每个更新步骤查看等式2中的和门值(gate values),以了解元学习器如何在训练期间更新学习器。在不同的数据集上训练时,我们将门值可视化,以观察训练集之间是否存在差异。我们考虑1-shot 和 5-shot分类设置,元学习器分别进行10次和5次更新。对于两个任务的遗忘门值,元学习器似乎采用了一种简单的权重衰减策略,这种策略似乎在不同的层次上是一致的。在收集元学习器的策略时,输入门值很难解释。然而,在不同的数据集之间似乎有很多可变性,这表明元学习器不仅仅是学习一个固定的优化策略。此外,这两个任务之间似乎存在差异,这表明元学习器在处理不同设置的条件时采用了不同的方法。

论文阅读(46)OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING_第6张图片
image.png

6 结论

我们描述了一个基于LSTM的元学习模型,它的灵感来自于梯度下降优化算法提出的参数更新。我们的LSTM元学习器使用它的状态来表示分类器参数的学习更新。它不仅可以发现学习器参数的良好初始化,还可以发现一种成功的机制,用于将学习器参数更新到给定的小训练集,以完成一些新的分类任务。我们的实验表明,我们的方法优于自然基线,并且在小样本学习下,在度量学习方面有竞争力。

在这项工作中,我们的研究集中在小样本和小类别设置。然而,对元学习器的训练将更有价值,因为元学习器能够在全方位的环境中表现良好,例如,对于少数或大量的训练示例和少数或大量可能的类。因此,我们今后的工作将考虑朝着这一更具挑战性的场景迈进。

参考资料

[1]

代码

[1] twitter/meta-learning-lstm

你可能感兴趣的:(论文阅读(46)OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING)