论文 | matching net 《Matching Networks for One Shot Learning》

一 写在前面

未经允许,不得转载,谢谢~~

这是做小样本图像分类的文章,文章是2016年发在nips上的,但是现在为止还是作为很多one-shot小样本领域的competitor进行比较。

  • 文章出处:NIPS2016
  • 文章链接:https://arxiv.org/abs/1606.04080

二 内容介绍

这里就按照个人的理解简单总结一下文章内容了

2.1 one shot learning 小样本学习

目前的网络依赖大量的标注数据,但是人能够只通过几张图像就快速学习到一个新概念。

由这一点受到启发,引出小样本学习的概念,即测试阶段,对于没有见过的类别只有少数几个标注样本的情况下,如何快速的学习这个概念 ,进行识别。 (可能会有理解或者总结不到位的地方,望见谅~)

2.2 本文主要工作

主要有两个部分的创新之处:

  • modeling level: 提出matching net,利用attention以及memory来获取快速学习新概念的能力;
  • training procedure:在训练时采取与测试一样的原则(只有少量的标准样本)

2.2.1 training procedure

这个地方的创新之处在于不是简单的利用所有的标注数据进行模型的训练,然后再测试阶段进行测试。而是直接在训练的时候就去模仿测试时只有少量标注样本的情况,提出episode的概念。

就是每个episode会包含一个support set(充当训练数据),和一个测试集合batch(充当测试数据)。其中每个support set都是随机生成的,比较有名的就是N-way-K-shot的模式。在所有的类别中随机选择N个类别,然后每个类别又随机选择K个样本作为支持集(k通常小,1-5)。

这样的好处是模拟测试阶段只有少量标注数据的情况,就是测试阶段怎么使用,训练的时候就怎么训练。

其实就是meta-learning的概念。

那我们就知道对于每个episode,网络的优化目标一定是要所有batch中测试样本产生的loss最小。

论文 | matching net 《Matching Networks for One Shot Learning》_第1张图片
  • S 表示Support set,其中包含n*k个标注样本{x_i,y_i}
  • B 表示测试集合Batch
  • 那么现在问题就转换给定S的情况下,假设有一个测试样本x,如何得到x属于y的概率;

2.2.2 model architecture

这里来解决给定S的情况下,假设有一个测试样本x,如何得到x属于y的概率的问题;

用数学表示其实就是:


这里作者给出的计算方法是:


论文 | matching net 《Matching Networks for One Shot Learning》_第2张图片
  • 其中a表示attention机制;
  • 不管a的情况下,其实可以看成是支持向量集S中{x_i,y_i}所有样本的线性结合,只是给他们赋予了不同的权重。
  • 示意图:


    论文 | matching net 《Matching Networks for One Shot Learning》_第3张图片

对于a作者采用的也是最简单的方式,就是用x^样本与x样本之间的特征(文中成为embedding)的cosine距离的softmax值进行计算。


  • c表示cosine距离;
  • f和g其实就如上图所示,表示特征提取器;

整篇文章到这里都挺简单的,没有什么特别复杂的地方,设计的方法也比较自然合理。

但是文章对对于f和g这2个特征提取器的设计确实是下了一番功夫。

g特征提取器

  • 首先一般的都是g(x)这样的函数表示,及输出只与输入x有关;

  • 但是作者觉得这样不够,x的特征除了与x本身有关外,还应该与support set中其他的样本也相关;

  • 所以他将原本无序的集合看成是有序的,然后用双向LSTM进行建模;

  • 最后得到的g表示:




  • 其中g'(x), h(i-1), c(i-1), h(i+1), c(i+1)分别表示原始状态,上一时刻的隐含状态,上一时刻的记忆状态,下一时刻的隐含状态,下一时刻的记忆状态;

h特征提取器

  • 对于每个测试样本,使用注意力LSTM来获取测试样本的特征;
  • 最终得到的h表示:


    论文 | matching net 《Matching Networks for One Shot Learning》_第4张图片
  • 其中rc分别表示读数状态和记忆状态;
  • 最后一步得到的h即为测试样本的特征;

三 写在最后

整体的文章思路挺自然易懂的,但是g和h的特征提取部分确实不是很容易吸收,需要有耐性去看,个人感觉自己的基础功还不是很扎实,对这部分的解读不够透彻。

这里也推荐两篇我觉得写得不错的博客给大家,尤其是【平价数据】One Shot Learning
)这篇真得写的很棒。

参考资料:

  • 【One Shot】《Matching Networks for One Shot Learning》
  • 【平价数据】One Shot Learning

你可能感兴趣的:(论文 | matching net 《Matching Networks for One Shot Learning》)