匹配网络 Matching Network

匹配网络 Matching Network

匹配网络其实就是引入注意力机制,通过对 embedding 后的特征计算注意力,利用注意力得分进行分析:
匹配网络 Matching Network_第1张图片
首先也是对支持集和查询集进行 embedding,然后用查询集样本对每个支持集样本计算注意力
a ( x ^ , x i ) = e c ( f ( x ^ ) , g ( x i ) ) / ∑ j = 1 k e c ( f ( x ^ ) , g ( x j ) ) a\left(\hat{x}, x_{i}\right)=e^{c\left(f(\hat{x}), g\left(x_{i}\right)\right)} / \sum_{j=1}^{k} e^{c\left(f(\hat{x}), g\left(x_{j}\right)\right)} a(x^,xi)=ec(f(x^),g(xi))/j=1kec(f(x^),g(xj))

其中:

  • f 和 g是我们选择的合适的神经网络,一般 f = g,用于输入的 embedding
  • x i x_i xi 是支持集, x ^ \hat x x^ 是查询集
  • c 是余弦距离

计算了注意力之后,就分析查询集的样本:
P ( y ^ ∣ x ^ , S ) = ∑ i = 1 k a ( x ^ , x i ) y i P(\hat{y} \mid \hat{x}, S)=\sum_{i=1}^{k} a\left(\hat{x}, x_{i}\right) y_{i} P(y^x^,S)=i=1ka(x^,xi)yi
其中:

  • y i y_i yi 是每个类别的标签,其实就是把每个类别根据注意力得分进行线性加权
  • P 是计算出对应类别的概率

最后的训练目标为:
θ = arg ⁡ max ⁡ θ E L ∼ T [ E S ∼ L , B ∼ L [ ∑ ( x , y ) ∈ B log ⁡ P θ ( y ∣ x , S ) ] ] \theta=\arg \max _{\theta} E_{L \sim T}\left[E_{S \sim L, B \sim L}\left[\sum_{(x, y) \in B} \log P_{\theta}(y \mid x, S)\right]\right] θ=argθmaxELTESL,BL(x,y)BlogPθ(yx,S)

个人总结:
总的来说,匹配网络把整个分析的过程都简化到注意力计算过程中,如果某个类别的注意力得分比较高,其实就意味着测试样本属于这个类别的可能性比较大,所以模型的训练重点就回到最初的 embedding 中。

Vinyals O, Blundell C, Lillicrap T, et al. Matching networks for one shot learning[J]. Advances in neural information processing systems, NIPS 2016, 29: 3630-3638.
元学习系列(四):Matching Network(匹配网络)

你可能感兴趣的:(论文笔记,匹配网络,小样本学习)