作者提出了一种新的Loss来训练零样本的生成网络,梯度匹配Gradient Matching(GM) Loss。该Loss可以测量生成样本产生的梯度信号的质量。使用GMLoss可以迫使生成器产生的样本能够训练出精确的分类器。
生成未见类的样本(仅使用未见类的嵌入向量),将zsl问题转化成监督学习问题。
分类期的精度太依赖于生成样本的多样性和保真度。因此本文关注两个方向:
1)利用合成样本来隐式建模每个未见类的流形
2)保证合成样本训练出来的分类器效果好
L W G A N = E x ∼ P r [ D ( x ) ] − E x ~ ∼ P g [ D ( x ~ ) ] + λ E x ^ ∼ P x ^ [ ( ∣ ∣ ∇ x ^ D ( x ^ ) ∣ ∣ 2 − 1 ) 2 ] (1) \begin{aligned} \mathcal{L}_{WGAN} =& \mathop{\mathbb{E}}\limits_{x \sim P_r } [\mathcal D(x)] - \mathop{\mathbb{E}}\limits_{\tilde x \sim P_g } [\mathcal D(\tilde x)] \\ & + \lambda \mathop{\mathbb{E}}\limits_{\hat x \sim P_{\hat x } } [ (|| \mathcal \nabla_{\hat x} D(\hat x)||_2 - 1)^2] \end{aligned} \tag{1} LWGAN=x∼PrE[D(x)]−x~∼PgE[D(x~)]+λx^∼Px^E[(∣∣∇x^D(x^)∣∣2−1)2](1)
P r P_r Pr是真实的数据分布, P g P_g Pg是生成器的输出, x ^ \hat x x^是 x x x和 x ~ \tilde x x~的插值。
一般来说,生成器的输入是一个先验分布里采样的噪声。然而,我们想要生成未见类的训练样本,所以要结合噪声和类别嵌入。一个简答的方法就是直接连接噪声和类别嵌入。然而,我们也可以建模类别对应的隐含分布,从该隐含分布中采样训练样本。因此,本文提出了条件多远高斯分布 N ( μ ( a ) , Σ ( a ) ) \mathcal{N}(\mu(a), \Sigma(a)) N(μ(a),Σ(a))。
作者认为分类loss测量了合成样本在预训练分类模型上的损失,而不是期望的在合成样本上训练的模型的损失。基于以上观察,相对于产生那些使预训练模型准确分类的样本,应该着力学习生成能够训练精确的分类模型的训练样本。
一种直接的方案是通过最小化一个分类模型在合成样本上的损失来训练生成器。但是这样做效果很差,原因有两个:
因此,我们把关注点放在:最大化每一次模型更新的准确性。我们观察到一个简单的事实:一个生成模型学习真实类别流形的情形下,在一个大的生成样本集合上训练的分类模型参数的损失函数的偏导数是和一个大的真实训练样本的集合上训练的分类模型参数的偏导数是高度相关的。
疑问:如果真实样本的初始点和合成样本的初始点不一样,那么梯度方向一致,没有任何意义???
基于这些观察,我们提出最小化在可见类生成样本上获得的梯度。更具体地说,我们提出学习一个生成模型,它可以最大化真实样本上的梯度和合成样本上的梯度的相关性。
g r ( θ ) = E ( x , a ) ∼ D [ ∇ θ f L C L S ( f , x , a ; θ f = θ ) ] , (2) g_r(\theta)=\mathop{\mathbb{E}} \limits_{(x,a) \sim \mathcal{D}} [\nabla_{\theta_f} \mathcal{L}_{CLS}(f, x,a; \theta_f = \theta)], \tag{2} gr(θ)=(x,a)∼DE[∇θfLCLS(f,x,a;θf=θ)],(2)
g s ( θ ) = E x ~ ∼ G ( a ∼ A s ) [ ∇ θ f L C L S ( f , x , a ; θ f = θ ) ] , (3) g_s(\theta)=\mathop{\mathbb{E}} \limits_{\tilde{x} \sim \mathcal{G}(a \sim {\mathcal A}_s)} [\nabla_{\theta_f} \mathcal{L}_{CLS}(f, x,a; \theta_f = \theta)], \tag{3} gs(θ)=x~∼G(a∼As)E[∇θfLCLS(f,x,a;θf=θ)],(3)
其中, L C L S ( f , x , a ) \mathcal{L}_{CLS}(f, x,a) LCLS(f,x,a) 是用来训练相容性函数 f ( x , a ; θ f ) f(x, a; \theta_f) f(x,a;θf)的损失函数。训练过程中,我们会近似样本batch上的 g r g_r gr和 g s g_s gs。
因为梯度的意义是指向局部极小值的方向,而不是梯度的绝对大小。所以作者采用余弦相似度来度量 g r g_r gr和 g s g_s gs的差异。
L G M = E θ [ 1 − g r ( θ ) T g s ( θ ) ∣ ∣ g r ( θ ) ∣ ∣ 2 ∣ ∣ g s ( θ ) ∣ ∣ 2 ] , (4) \mathcal{L}_{GM} = \mathop{\mathbb{E}}\limits_{\theta} [1 - \frac{g_r(\theta)^Tg_s(\theta)}{||g_r(\theta)||_2||g_s(\theta)||_2}], \tag{4} LGM=θE[1−∣∣gr(θ)∣∣2∣∣gs(θ)∣∣2gr(θ)Tgs(θ)],(4)