[MTL(CVPR2019)]论文阅读笔记

Meta-Transfer Learning for Few-Shot Learning 论文地址 代码

写在前面

这是cvpr19年又一篇用meta-learning 做 few-shot learning的文章, 跟MAML不同的地方, 就是它不做全部参数的finetune了, 需要更新的参数变少, 使得网络不容易过拟合当前的task, 也使得网路收敛的快, 同时还使用了hard sample mining的方式来提高网络的精度.

Motivation

MAML这种meta-learning 做 few-shot 的方法存在两个主要的问题:

  • 这类方法一般都需要很多相似的任务来实现meta-train;
  • 之前都是用浅的网络来做few-shot这种task, 因为网络一深容易过拟合, 所以很难去使用一些比较厉害的深网络.

Contribution

  • 这篇文章提出了一个meta-transfer learning (MTL) 方法来结合迁移学习和meta-learning , 使得预训练的深度网络迁移到few-shot 的任务上, 网路收敛快并且避免了过拟合;
  • 提出了一个Hard Task (HT) meta-batch 训练策略, 使得网络在艰难样本下性能更好.

Algorithm

总的流程图如下:
[MTL(CVPR2019)]论文阅读笔记_第1张图片

1. DNN training on large-scale data

这边网络的预训练就是与一般的分类网络训练方式一样, 并不是用meta-train的方式来训的, 就比如如果原始的数据集中有64个类, 则训练出来的分类器也是输出64个分类score的. 这篇文章把网络分成了特征提取器 Θ \Theta Θ 和base-learner θ \theta θ (其实就是分类器). 网络训练的时候随机初始化 Θ \Theta Θ θ \theta θ , 其更新方式就是一般的梯度下降的方式:
[ Θ ; θ ] = : [ Θ ; θ ] − α ∇ L D ( [ Θ ; θ ] ) , [\Theta; \theta] = : [\Theta; \theta]-\alpha \nabla L_D([\Theta; \theta]), [Θ;θ]=:[Θ;θ]αLD([Θ;θ]),
L D ( [ Θ ; θ ] ) = 1 ∣ D ∣ ∑ ( x , y ) ∈ D l ( f [ Θ ; θ ] ( x ) , y ) L_D([\Theta; \theta]) = \frac{1}{|D|}\sum_{(x,y) \in D} l(f_{[\Theta; \theta]}(x),y) LD([Θ;θ])=D1(x,y)Dl(f[Θ;θ](x),y)
这里的 l l l 指的是一些距离度量函数, 可以是交叉熵损失, α \alpha α 是学习率. 当训练完后, θ \theta θ 会被砍掉, 然后接新的分类器, 因为few-shot 的任务sample的类别数和总的训练数据集肯定是不一样的.

2. Meta-transfer learning (MTL)

这块介绍了他们提出的scaling 和 Shifting (SS) Φ S { 1 , 2 } \Phi_{S_{\{1,2\}}} ΦS{1,2} 模块, 就是他们在迁移的时候不是直接更新网络参数, 而是在原始的weight和bias上做了一些操作, 这样不仅减少了参数的学习, 同时也保留了预训练时学到的general的信息不被破坏, 减少了过拟合的概率. 细节如下:

对于一个task T = { T ( t r ) , T ( t e ) } T = \{T^{(tr)}, T^{(te)} \} T={T(tr),T(te)}, T ( t r ) T^{(tr)} T(tr)的损失就是用来更新当前的base-learner (classifier) θ ′ \theta' θ :
θ ′ ← θ − β ∇ θ L T ( t r ) ( [ Θ ; θ ] , Φ S { 1 , 2 } ) \theta' \leftarrow \theta- \beta \nabla_{\theta} L_{T^{(tr)}}([\Theta; \theta] , \Phi_{S_{\{1,2\}}}) θθβθLT(tr)([Θ;θ],ΦS{1,2})
这里的 θ \theta θ 和上一节的 θ \theta θ 不一样, 这里的 θ \theta θ 是指只有某几个类的分类器参数; 然后跟上一节不一样的地方就是, 这里不更新 Θ \Theta Θ, Θ \Theta Θ 作为特征提取器, 在预训练完之后就一直不会变.

然后 T ( t e ) T^{(te)} T(te) 的损失是用来更新 Φ S { 1 , 2 } \Phi_{S_{\{1,2\}}} ΦS{1,2} , 这里 Φ S 1 \Phi_{S_1} ΦS1初始化为全1, Φ S 2 \Phi_{S_2} ΦS2 初始化为0. 更新过程如下:
Φ S i = : Φ S i − γ ∇ Φ S i L T ( t e ) ( [ Θ ; θ ′ ] , Φ S { 1 , 2 } ) \Phi_{S_i} = : \Phi_{S_i} - \gamma \nabla_{ \Phi_{S_i}} L_{T^{(te)}}([\Theta; \theta'] , \Phi_{S_{\{1,2\}}}) ΦSi=:ΦSiγΦSiLT(te)([Θ;θ],ΦS{1,2})
θ = : θ − γ ∇ Φ S i L T ( t e ) ( [ Θ ; θ ′ ] , Φ S { 1 , 2 } ) \theta = :\theta - \gamma \nabla_{ \Phi_{S_i}} L_{T^{(te)}}([\Theta; \theta'] , \Phi_{S_{\{1,2\}}}) θ=:θγΦSiLT(te)([Θ;θ],ΦS{1,2})

下面介绍如何将 Φ S { 1 , 2 } \Phi_{S_{\{1,2\}}} ΦS{1,2} 应用到网络中, Θ \Theta Θ 中的所有参数都用 W , b W,b W,b表示, 所以对于一个输入 X X X, 经过SS后提取的特征:
S S ( X ; W , b ; Φ S { 1 , 2 } ) = ( W ⊙ Φ S 1 ) X + ( b + Φ S 2 ) SS(X; W, b;\Phi_{S_{\{1,2\}}}) = (W \odot \Phi_{S_1})X+(b+\Phi_{S_2}) SS(X;W,b;ΦS{1,2})=(WΦS1)X+(b+ΦS2)
如下图:
[MTL(CVPR2019)]论文阅读笔记_第2张图片
SS有以下三个优点:

  1. 利用了深度的DNN提供了一个strong的初始化, 使得MTL可以很快收敛;
  2. 没有改变DNN的权重, 避免破坏原始网络学到的general的信息;
  3. SS是很轻量的, 减少了过拟合的概率.

3 Hard task (HT) meta-batch

这边就是对于所有的task , 先用当前的网络去测试一下, 将每个task中分类acc最低的类别记录下来, 在所有的train完之后, 再用这些类别的sample组合成难的task来训练, 这是参考了课程学习的概念, 使得网络在逐步变难的样本中逐步增加网络的分类性能.

本文的算法流程可以从以下两张图和直观的看出:

[MTL(CVPR2019)]论文阅读笔记_第3张图片
[MTL(CVPR2019)]论文阅读笔记_第4张图片

这里我比较有疑问的地方就是 θ ′ \theta' θ θ \theta θ 之间的关系, 若是 θ ′ \theta' θ 的更新都是根据 θ \theta θ, 那算法2中的4, 有什么意义吗, 若是 θ \theta θ 不变, 这边迭代就没有意义了. 可能需要看代码才知道了.

你可能感兴趣的:([MTL(CVPR2019)]论文阅读笔记)