如何理解few-shot learning中的n-way k-shot?

原文:https://www.zhihu.com/question/363200569/answer/2626785660?utm_id=0

作者:胖迪王

链接:https://www.zhihu.com/question/363200569/answer/2626785660

来源:知乎

著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

论文:[1606.04080] Matching Networks for One Shot Learning (arxiv.org)

时间:NIPS 2016

最近在读《Matching Networks for One Shot Learning》这篇文章,里面好多内容有些疑问,以下参考博客并结合自己的理解,可能有些地方存在问题,希望大家多多指正。每天学一点知识,你将变得更优秀哒。

N-way-K-shot任务

N-way-K-shot任务就是将任务 τ\tau 划分为N个类别(way),每个类别的支持集(support set)包含K个样本,任务 τ\tau 中剩余的样本作为该任务的验证集(query set).其中每个"任务"包含支持集(support set)和验证集(query set)

如何理解few-shot learning中的n-way k-shot?_第1张图片
如何理解few-shot learning中的n-way k-shot?_第2张图片

匹配网络

目的:提供一个网络框架,能将少量数据集和未标记的实例映射到所属标签,避免通过微调已训练好的模型来适应新类

创新点:结合度量学习和记忆增强神经网络的新型神经网络结构----匹配网络

对于少量数据集而言,模型在拟合数据时,可能会产生过拟合问题 ,这个问题可使用正则化和数据增强方式来缓和。但这些都是治标不治本。训练样本需要被参数模型通过梯度下降对参数进行更新,使得学习速率比较缓慢。对于许多非参数模型能快速同化新的实例并且不会遭受遗忘。作者结合参数模型和非参数模型来获取新的实例,提高模型的泛化能力。作者从注意力的序列到序列(seq2seq)、记忆网络以及指针网络中获得灵感。提出了匹配网络,它利用注意力机制和记忆机制加速学习,实现在少量数据的条件下对无标签的实例进行标签预测。

符合定义:支持集 S=(xi,yi)i=1kS=(x_i,y_i)_{i=1}^k ,预测类别的图像为 x^\hat{x}

算法理论:

1.基于余弦距离的注意力机制

通过余弦距离计算训练实例 xix_i 与测试实例 x^\hat{x} 之间的相似度,通过softmax对相似度进行归一化后得到测试实例x^\hat{x}在训练样本xix_i上的注意力分布 a(x^,xi)a(\hat{x},x_i)

a(x^,xi)=ec(f(x^),g(xi))∑j=1kec(f(x^),g(xi))a(\hat{x},x_i)=\frac{e^{c(f(\hat{x}),g(x_i))}}{\sum _{j=1}^ke^{c(f(\hat{x}),g(x_i))}}

其中,嵌入函数g和f的作用是将xix_i和x^\hat{x}嵌入(embadding)到空间中(特征提取)

模型的输出 yi^\hat{y_i} :

P(y^|x^,S)=∑i=1ka(x^,xi)yiP(\hat{y}|\hat{x},S)=\sum_{i=1}^{k}a(\hat{x},x_i)y_i

2.Full Context Embeddings

如何理解few-shot learning中的n-way k-shot?_第3张图片
如何理解few-shot learning中的n-way k-shot?_第4张图片
如何理解few-shot learning中的n-way k-shot?_第5张图片
如何理解few-shot learning中的n-way k-shot?_第6张图片

1.训练集嵌入函数g

首先,通过一个普通的网络(VGG等)对支持集中训练样本的每个样本进行原始特征提取,记为 g′(xi)g'(x_i)

然后,采用一个双向LSTM模型,为每个训练实例xix_i设置四个状态量,分别是

如何理解few-shot learning中的n-way k-shot?_第7张图片
如何理解few-shot learning中的n-way k-shot?_第8张图片

前向隐状态 →hi\underset{h_i}{\rightarrow} 和 →ci\underset{c_i}{\rightarrow} ,由前一个训练实例 xi−1x_{i-1} 的隐状态和记忆细胞通过LSTM模型确定:

→hi,→ci=LSTM(g′(xi),→hi−1,→ci−1)\underset{h_i}{\rightarrow},\underset{c_i}{\rightarrow}=LSTM(g'(x_i),\underset{h_{i-1}}{\rightarrow},\underset{c_{i-1}}{\rightarrow})

前向隐状态 ←hi\underset{h_i}{\leftarrow} 和 ←ci\underset{c_i}{\leftarrow} ,由前一个训练实例 xi+1x_{i+1} 的隐状态和记忆细胞通过LSTM模型确定:

←hi,←ci=LSTM(g′(xi),←hi+1,←ci+1)\underset{h_i}{\leftarrow},\underset{c_i}{\leftarrow}=LSTM(g'(x_i),\underset{h_{i+1}}{\leftarrow},\underset{c_{i+1}}{\leftarrow})

支持集的特征由前后隐状态和原始特征共同决定:

如何理解few-shot learning中的n-way k-shot?_第9张图片
如何理解few-shot learning中的n-way k-shot?_第10张图片

函数g特征提取时不仅考虑原始特征 g′(xi)g'(x_i) 还考虑该训练样本和支持集中的其他样本有某种相关性

2.测试集嵌入函数f

如何理解few-shot learning中的n-way k-shot?_第11张图片
如何理解few-shot learning中的n-way k-shot?_第12张图片
如何理解few-shot learning中的n-way k-shot?_第13张图片
如何理解few-shot learning中的n-way k-shot?_第14张图片

首先,通过一个普通的网络对测试集的单个样本进行特征提取,记为 f′(xi^)f'(\hat{x_i})

如何理解few-shot learning中的n-way k-shot?_第15张图片
如何理解few-shot learning中的n-way k-shot?_第16张图片

2. 获得当前时刻测试样本的隐状态

3. 获得训练集的特征的加权和记为read-out

如何理解few-shot learning中的n-way k-shot?_第17张图片
如何理解few-shot learning中的n-way k-shot?_第18张图片

4. 将 rkr_k 作为测试样本的特征

隐状态h决定了把注意力应该放在哪一些支持集的样本上。

以上介绍的两个嵌入函数g和f是论文中提到的Full Context Embeddings的两个部分

训练策略

对一个任务T和带标签的数据L,每个任务中最多包含5类,每一类最多含有5张图片。

训练流程:

如何理解few-shot learning中的n-way k-shot?_第19张图片
如何理解few-shot learning中的n-way k-shot?_第20张图片
  • 选择少数几个类别,为每个类别选择少量样本

  • 从选出的集合中划分支持集S和测试集Q

  • 通过本次迭代的支持集S来计算测试集上的误差

  • 计算梯度,更新参数

整个过程被称为episode.匹配网络主要是减少测试集B在支持集S上的分类损失

如何理解few-shot learning中的n-way k-shot?_第21张图片
如何理解few-shot learning中的n-way k-shot?_第22张图片

如何理解few-shot learning中的n-way k-shot?_第23张图片
如何理解few-shot learning中的n-way k-shot?_第24张图片

训练完成后,在 novel 类别中再抽样出 S' 和 T',再调用 θ 完成分类任务,当 T' 与 T 相差较大时效果不好。关键的是,匹配网络不需要对它从未见过的类进行任何微调,因为它的非参数性质

实验:

对比模型有原始像素匹配,Baseline Classifier(鉴别特征匹配),MANN,Convolutional Siamese Net。其中Baseline Classifier中的图像分类是训练数据集中的原始类,但排除N个类。然后在最后一层(在softmax之前)的特征进行最近邻匹配

Omniglot(未使用全文本嵌入)

Omniglot数据集包含50个字母表,共计1623类字符,每类包含20个不同人绘制的20个样本。

本文在使用时,添加了90°为倍数的4种旋转,进一步扩展类别数。使用其中的12004类字符作为训练,剩余4234类作为测试

如何理解few-shot learning中的n-way k-shot?_第25张图片

匹配网络的性能超过baselines ,即使对baselines的S'上进行微调后,无论是使用余弦距离还是softmax,baselines的泛化也很好

如何理解few-shot learning中的n-way k-shot?_第26张图片
如何理解few-shot learning中的n-way k-shot?_第27张图片

ImageNet

miniImageNet:选择100类,80类训练,20类测试。每类包含600张84*84的彩色图像。

如何理解few-shot learning中的n-way k-shot?_第28张图片
如何理解few-shot learning中的n-way k-shot?_第29张图片

randImageNet:随机选择118类作为测试集;剩余类作为训练集。

dogsImageNet:选择dogs的118个子类作为测试集;剩余类作为训练集。

如何理解few-shot learning中的n-way k-shot?_第30张图片
如何理解few-shot learning中的n-way k-shot?_第31张图片

结论:

如果训练网络进行 one-shot,那么 one-shot 会容易得多。

神经网络中的非参数结构使网络更容易记忆和适应相同任务中的新数据集

缺点:随着支持集S的大小增长,每个梯度更新的计算变得更加昂贵

参考来源
【平价数据】One Shot Learning_shenxiaolu1984的博客-CSDN博客
Part7 _ Matching Networks_哔哩哔哩_bilibili(有几张的图片来源于该视频)

你可能感兴趣的:(自然语言,深度学习,人工智能)