元学习(meta-learning)是目前广泛使用的处理小样本学习问题的方法,它的目的是在学习不同任务的过程中积累经验,从而使得模型能够快速适应新任务。比如在MAML(Model-Agnostic Meta- Learning)中,通过搜寻最优初始化状态,使得base-learner能够快速适应新任务。但是这一类元学习方法具有两个缺点:
本文提出了一种新的元学习方法,称为meta-transfer learning(MTL),当仅使用少量带有标记的数据时,它可以帮助深度神经网络快速收敛,并且降低过拟合发生的概率,
通过大规模数据集训练得到的DNN权值提供了一个很好的初始化状态,可以确保MTL在处理小样本任务时能够快速收敛。而在DNN神经元上进行的轻量级操作也使得模型的参数更少,降低了过拟合的可能。除此之外,这些操作保持DNN的权值不被改变,从而避免了当适应新任务时,模型会遗忘通用模式这种情况的发生。
本文还提出了一种新的学习策略,称为hard task(HT)meta-batch,以往的meta-batch包含的是一些随机任务,而HT meta-batch根据之前在训练时出现的具有较低验证准确度的失败任务,对hard task进行重新采样。
本文的贡献如下:
这一阶段类似于目标识别中的预训练阶段,这里是在小样本学习benchmark的现成数据集上进行预训练。对于一个确切的小样本数据集,将会融合所有类的数据 D D D以进行预训练。比如在miniImageNet中, D D D的训练集中共有64个类,每个类包含600个样本,那么将利用所有这些数据进行预训练,得到一个64-class分类器。
首先随机初始化一个特征提取器 Θ \Theta Θ和分类器 θ \theta θ,通过梯度下降对它们进行优化:
其中 L L L是交叉熵损失:
这一阶段主要通过学习得到一个特征提取器 Θ \Theta Θ,在后续的meta-training和meta-test阶段, Θ \Theta Θ将会被冻结,而这一阶段得到的分类器 θ \theta θ将会被去掉。
MTL通过HT meta-batch训练来对元操作(meta operation)SS进行优化,将SS操作分别定义为 Φ S 1 \Phi_{S_1} ΦS1和 Φ S 2 \Phi_{S_2} ΦS2,给定任务 T T T, T ( t r ) T^{(tr)} T(tr)是训练数据,使用 T ( t r ) T^{(tr)} T(tr)的损失来优化当前的base-learner(分类器) θ ′ \theta^{'} θ′,也就是对 θ \theta θ进行更新:
与式(1)不同的是,这里并没有更新 Θ \Theta Θ。注意这里的 θ \theta θ与式(1)中的 θ \theta θ是不同的,在式(1)中 θ \theta θ处理的是某个数据集中的所有类,而这里的 θ \theta θ只关注少量的几个类,从而在小样本设置的情况下进行分类。 θ ′ \theta^{'} θ′是一个临时的分类器,它只关注当前的任务。
Φ S 1 \Phi_{S_1} ΦS1由1进行初始化, Φ S 2 \Phi_{S_2} ΦS2由0进行初始化,然后,用 T ( t e ) T^{(te)} T(te)的损失对它们进行优化, T ( t e ) T^{(te)} T(te)是测试数据:
然后用与式(4)中相同的学习率 γ \gamma γ对 θ \theta θ进行更新:
然后说一下如何将 Φ S { 1 , 2 } \Phi_{S_{\lbrace 1,2 \rbrace}} ΦS{1,2}应用到 Θ \Theta Θ的神经元上。给定 Θ \Theta Θ,它的第 l l l层包含 K K K个神经元,也就是 K K K个参数对儿,定义为 { ( W i , k , b i , k ) } \lbrace (W_{i,k},b_{i,k})\rbrace {(Wi,k,bi,k)},分别表示权值和偏差。假设输入是 X X X,那么将 { Φ S { 1 , 2 } } \lbrace \Phi_{S_{\lbrace 1,2 \rbrace}} \rbrace {ΦS{1,2}}应用到 ( W , b ) (W,b) (W,b)上就是:
下图说明了分别通过SS和FT进行更新的不同之处:
在以往的元训练中,meta-batch包含的是随机采样的任务,也就是说任务的难度也是随机的。本文在元训练中有意挑选出每个任务中的失败案例(failure case),并将其数据重新组合为难度较大任务,迫使meta-learner"在困难中成长"。
每个任务 T T T都包含 T ( t r ) T^{(tr)} T(tr)和 T t e T^{{te}} Tte,分别用于base-leaning和test,base-learner由 T ( t r ) T^{(tr)} T(tr)的损失进行优化(在多个epoch中进行迭代),然后SS由 T ( t e ) T^{(te)} T(te)的参数进行优化(只进行一次)。对于 T ( t e ) T^{(te)} T(te),可以得到其中 M M M个类的识别精度,然后根据最低的精度 A c c m Acc_{m} Accm判断当前任务中最困难的类class- m m m(failure class)。在从当前的meta-batch { T 1 − k } \lbrace T_{1-k} \rbrace {T1−k}的所有 k k k个任务中获得所有failure class之后,从这些数据中重新采样任务。也就是说,假设 p ( T ∣ { m } ) p (T|\lbrace m\rbrace) p(T∣{m})是任务分布,那么采样harder task T t a s k ∈ p ( T ∣ { m } ) T^{task} \in p (T|\lbrace m\rbrace) Ttask∈p(T∣{m}),采样的具体细节如下:
Algorithm 1总结了DNN的训练和MTL这两个阶段,其中的failure class由Algorithm 2返回。Algorithm 2说明了在单个任务上的学习过程,包括episode training和episode test
本文在MAML的基础上,使用了一个较深的预训练DNN模型,为了更好地发挥DNN的效果,在固定DNN每层参数不变的情况下,为每层的权值和偏差分别设置了可学习的scaling和shifting,这样可以降低参数数量,避免过拟合。除此之外,为了增强模型的泛化能力和鲁棒性,本文使用HT meta-batch学习策略。