极简笔记 Meta-Learning for semi-supervised few-shot classification

极简笔记 Meta-Learning for semi-supervised few-shot classification

论文地址 https://arxiv.org/pdf/1803.00676.pdf

本篇文章核心是给出了一种用于少样本半监督学习的分类算法。总体思路非常简单,通过一个网络 (Prototypical Network)提取特征,之后对特征进行聚类,聚类中心稳定之后拿去测试,测试时把测试样本按照更新了的聚类中心归过去,就完事了。

训练集中样本又分为两个集合 (S,R) ( S , R ) ,分别表示有类别标注数据和无标注数据。网络映射记作 h(x) h ( x ) , 有标注样本 xi x i , 无标注样本 x˜i x ~ i , 类别中心记作 pc p c , 更新过程中记作 p˜c p ~ c ,更新过程是典型的EM算法,和K-Means非常像,即把没标注的数据先归到最近类,更新聚类中心,再重新把没标注数据归类,迭代直到不变。
聚类过程
如果直接用K-Means聚类那这篇文章就不用发ICLR了。文章用了三种聚类算法:

soft K-Means

聚类中心更新公式

p˜c=ih(xi)zi,c+jh(x˜j)z˜j,cizi,c+jz˜j,cmj,c,where z˜j,c=exp(||h(x˜jpc||22)cexp(||h(x˜j)pc||22) p ~ c = ∑ i h ( x i ) z i , c + ∑ j h ( x ~ j ) z ~ j , c ∑ i z i , c + ∑ j z ~ j , c m j , c , where  z ~ j , c = e x p ( − | | h ( x ~ j − p c | | 2 2 ) ∑ c ′ e x p ( − | | h ( x ~ j ) − p c ′ | | 2 2 )

soft K-Means with cluster

由于无标注样本集合R中一些样本真实类别并没有出现在集合S中,这类样本被称作distractor class。为了防止此类样本污染标注类别集合,作者修改了聚类规则,认为distractor class类别中心始终在原点:

pc={ih(xi)zi,cizi,c0for c=1...Nfor c=N+1 p c = { ∑ i h ( x i ) z i , c ∑ i z i , c for  c = 1... N 0 for  c = N + 1

此外再考虑引进类别半径表示类内样本的不一致性(为了方便起见,标注类别半径 r1...N=1 r 1... N = 1 ,只学习无标注样本类别半径 rN+1 r N + 1
z˜j,c=exp(1r2c||x˜jpc||22A(rc))cexp(1r2c||x˜jpc||22A(rc)),where A(r)=12log(2π)+log(r) z ~ j , c = e x p ( − 1 r c 2 | | x ~ j − p c | | 2 2 − A ( r c ) ) ∑ c ′ e x p ( − 1 r c ′ 2 | | x ~ j − p c ′ | | 2 2 − A ( r c ′ ) ) , where  A ( r ) = 1 2 l o g ( 2 π ) + l o g ( r )

masked soft K-Means

定义样本 x˜j x ~ j 到类别 c c 的距离:

d˜j,c=dj,c1Mjdj,c,where dj,c=||h(x˜j)pc||22 d ~ j , c = d j , c 1 M ∑ j d j , c , where  d j , c = | | h ( x ~ j ) − p c | | 2 2

另外再用MLP学习两个阈值 βc,γc β c , γ c (也许这就是为什么文章title叫meta-learning的原因了。。。)
[βc,γc]=MLP([minj(d˜j,c),maxj(d˜j,c),varj(d˜j,c),skewj(d˜j,c),kurtj(d˜j,c)]) [ β c , γ c ] = M L P ( [ m i n j ( d ~ j , c ) , m a x j ( d ~ j , c ) , v a r j ( d ~ j , c ) , s k e w j ( d ~ j , c ) , k u r t j ( d ~ j , c ) ] )

然后是聚类中心的更新公式:
p˜c=ih(xi)zi,c+jh(x˜j)z˜j,cmj,cizi,c+jz˜j,cmj,c,where mj,c=sigmoid(γc(d˜j,cβc)) p ~ c = ∑ i h ( x i ) z i , c + ∑ j h ( x ~ j ) z ~ j , c m j , c ∑ i z i , c + ∑ j z ~ j , c m j , c , where  m j , c = s i g m o i d ( − γ c ( d ~ j , c − β c ) )

最后来一张三种聚类方法的效果比对,其中Supervised表示只进行监督学习,并且直接分类;Semi-supervised inference表示只监督学习,在测试过程中聚类计算。这俩方法都是baseline。
这里写图片描述
这里写图片描述

你可能感兴趣的:(极简笔记)