元学习系列(四):Matching Network(匹配网络)

对一个小孩子来说,只要你展示了一次斑马的样子,以后他就能指出什么是斑马了,整个学习的过程只有一个样本,但是对深度学习算法来说还远远达不到这种学习程度,所以研究如何通过小样本甚至一个样本进行学习,就成为了few-shot、 one-shot learning的目的。

当然这里的小样本指的是某一类的样本比较小,比如要分辨猫狗鸡,可能鸡的样本只有几个,但是猫狗比较多,在这种情况下模型如何学习才能更好地分辨出鸡,就是小样本学习了。

关系网络其实就是引入注意力机制,通过对embedding后的特征计算注意力,利用注意力得分进行分析:

元学习系列(四):Matching Network(匹配网络)_第1张图片
首先也是对支持集和查询集进行embedding,然后用查询集样本对每个支持集样本计算注意力:

a ( x ^ , x i ) = e c ( f ( x ^ ) , g ( x i ) ) / ∑ j = 1 k e c ( f ( x ^ ) , g ( x j ) ) a(\hat x, x_i) = e^{c(f(\hat x), g(x_i))}/\sum_{j=1}^k e^{c(f(\hat x), g(x_j))} a(x^,xi)=ec(f(x^),g(xi))/j=1kec(f(x^),g(xj))

其中x hat是查询集,xi是支持集,c是余弦距离。

计算了注意力之后,就分析查询集的样本:

y ^ = ∑ i = 1 k a ( x ^ , x i ) y i \hat y = \sum_{i=1}^k a(\hat x , x_i) y_i y^=i=1ka(x^,xi)yi

yi是每个类别的标签,其实就是把每个类别根据注意力得分进行线性加权。

总的来说,我觉得匹配网络把整个分析的过程都简化到注意力计算的过程中,如果某个类别的注意力得分比较高,其实就意味着测试样本属于这个类别的可能性比较大,所以模型的训练重点就回到最初的embedding了。

在github写的自然语言处理入门教程,持续更新:NLPBeginner

在github写的机器学习入门教程,持续更新:MachineLearningModels

想浏览更多关于数学、机器学习、深度学习的内容,可浏览本人博客

你可能感兴趣的:(元学习)