动机/问题:广义零样本学习的技术难点。对已见类数据的过拟合导致对目标类别(已见类和未见类)的不确定预测,是GZSL性能低的原因。
如图,经过校正的网络预测更加准确。
问题:预测更加准确,是否能够提升分类精度?能够在实验中给出前后对比吗?
符号:
已见数据 D = { ( x n , y ) n ) } n = 1 N \mathcal{D} = \{ (x_n, y)n) \}_{n=1}^N D={(xn,y)n)}n=1N
源类别 S = { 1 , ⋯   , S } \mathcal{S}=\{ 1, \cdots, S \} S={1,⋯,S}
目标类 T = { S + 1 , ⋯   , S + T } \mathcal{T}=\{ S+1, \cdots, S+T \} T={S+1,⋯,S+T}, 训练时样本不可见
一个类别 c ∈ { S ∪ T } c \in \{\mathcal{S \cup T}\} c∈{S∪T}的语义表示为 a c ∈ R Q a_c \in \mathbb R^Q ac∈RQ
所有类别的语义表示 A = { a c } c = 1 S + T \mathcal{A}=\{a_c\}_{c=1}^{S+T} A={ac}c=1S+T
未见类数据 D ′ = { x m } m = N + 1 N + M \mathcal{D'} = \{ x_m \}_{m=N+1}^{N+M} D′={xm}m=N+1N+M, 源类或者目标类数据
定义1:零样本,ZSL Given D \mathcal{D} D and { a c } c = 1 S \{a_c\}_{c=1}^{S} {ac}c=1S, classify D \mathcal{D} D over target classes T \mathcal{T} T.
定义2:广义零样本,GZSL Given D \mathcal{D} D and { a c } c = 1 S + T \{a_c\}_{c=1}^{S+T} {ac}c=1S+T of both source and target classes, learn a model f : x ↦ y f: x \mapsto y f:x↦y to classify D ′ \mathcal{D'} D′ over both source and target classes S ∪ T \mathcal{S \cup T} S∪T.
在这个定义里,ZSL没有利用目标域的标签。
图像 x ∈ D x \in \mathcal{D} x∈D
特征嵌入 ϕ ( x ) ∈ R K \phi(x) \in \mathbb R^K ϕ(x)∈RK
类别语义 a ∈ A a \in \mathcal{A} a∈A,属性或者词向量
语义嵌入 ψ ( a ) ∈ R K \psi(a) \in \mathbb R^K ψ(a)∈RK
这里的嵌入空间就是特征空间,论文给出的是2048维的ResNet特征或者1024维的GoogleNet特征
图像的视觉嵌入 z n = ϕ ( x n ) z_n = \phi(x_n) zn=ϕ(xn)
类别的语义嵌入 v c = ψ ( a c ) v_c = \psi(a_c) vc=ψ(ac)
预测函数
f c ( x n ) = s i m ( ϕ ( x n ) , ψ ( a c ) ) f_c(x_n) = \rm sim(\phi(x_n), \psi(a_c)) fc(xn)=sim(ϕ(xn),ψ(ac))
s i m ( . , . ) \rm sim(., .) sim(.,.)是相似度函数,比如內积和余弦相似度; f c ( x n ) f_c(x_n) fc(xn)是(nearest prototype classifier) NPC分类器分配给图像 x n x_n xn类别 c c c的强度。
图像 x n x_n xn的预测类别 y ( x n ) y(x_n) y(xn)为
y ( x n ) = arg max c f c ( x n ) y(x_n)=\arg \max_c f_c(x_n) y(xn)=argcmaxfc(xn)
论文提到,预测源类和目标类的导致的技术难度是不一样的。
multi-class Hinge loss
∑ n = 1 N ∑ c = 1 S = max ( 0 , Δ ( y n , c ) + f c ( x n ) − f y n ( x n ) ) \sum_{n=1}^{N}\sum_{c=1}^{S}=\max (0, \Delta(y_n, c) + f_c(x_n)-f_{y_n}(x_n) ) n=1∑Nc=1∑S=max(0,Δ(yn,c)+fc(xn)−fyn(xn))
其中,间隔定义为
Δ ( y n , c ) = { 0 y n = c 1 y n ! = c \Delta(y_n, c) = \begin{cases} 0& {y_n = c}\\ 1& {y_n != c} \end{cases} Δ(yn,c)={01yn=cyn!=c
文中提到大部分零样本学习方法使用多分类Hinge损失来学习视觉语义映射。
作者应用温度校正来缓解由于在已见数据上的过拟合导致的对源域类别的过分相信。温度校正是Hinton老爷子提出来从深度网络蒸馏知识的。作者应用温度校正来将预测 f f f转换到源于类别上的概率分布
p c ( x n ) = exp ( f c ( x n ) / τ ) ∑ c ′ = 1 S exp ( f c ′ ( x n ) / τ ) p_c(x_n) = \frac {\exp(f_c(x_n)/\tau)} {\sum_{c'=1}^{S} \exp(f_{c'}(x_n)/\tau)} pc(xn)=∑c′=1Sexp(fc′(xn)/τ)exp(fc(xn)/τ)
其中, τ \tau τ就是温度,当 τ = 1 \tau=1 τ=1是深度网络里最常见的选项。温度 τ \tau τ用 τ > 1 \tau>1 τ>1“软化”了softmax。当 τ → ∞ \tau \to \infty τ→∞时,概率 p c → 1 / S p_c \to 1/S pc→1/S,这将导致最大的不确定性。当 τ → 0 \tau \to 0 τ→0时,概率坍缩到一点(即 p c = 1 p_c = 1 pc=1)。因为 τ \tau τ不改变softmax函数的最大值,收敛后如果应用温度校正 τ ≠ 1 \tau \neq 1 τ̸=1。
将概率 p c p_c pc插入到源域类别 S S S的可见数据 D \mathcal D D上的交叉熵损失得到
(6) L = − ∑ n = 1 N ∑ c = 1 S y n , c log p c ( x n ) . L = -\sum_{n=1}^{N} \sum_{c=1}^{S} y_{n, c} \log{p_c(x_n)}. \tag{6} L=−n=1∑Nc=1∑Syn,clogpc(xn).(6)
关于这个loss,作者认为,相比于multi-class Hinge loss,虽然交叉熵是一个很简单的处理多分类的方案,但能够利用温度校正来缓解过拟合。
不管是ZSL还是GZSL,都强调了模型训练不能使用目标域训练数据。但是,要用模型识别目标域的数据,必须让模型学习目标域的知识。所以就只能用到目标域的语义信息。
作者提出,将模型的预测 f c f_c fc转换成目标域上的概率(带有温度校正)。
(7) q c ( x n ) = exp ( f c ( x n ) / τ ) ∑ c ′ = S + 1 S + T exp ( f c ′ ( x n ) / τ ) q_c(x_n) = \frac {\exp (f_c(x_n)/\tau)} {\sum_{c'=S+1}^{S+T} \exp(f_{c'}(x_n)/\tau)} \tag{7} qc(xn)=∑c′=S+1S+Texp(fc′(xn)/τ)exp(fc(xn)/τ)(7)
温度校正 τ ≠ 1 \tau \neq 1 τ̸=1在公式(6)和(7)的端到端的训练中都会用到。
解释
直观上讲,目标域 c c c和源域图片 x n x_n xn对应的源域越相似,概率 q c ( x n ) q_c(x_n) qc(xn)的值越大。这样就避免了训练时源域图像对目标域图像的不确定性一致。在信息论中,熵 h ( q ) = − q log q h(q)=-q\log{q} h(q)=−qlogq是对分布 q q q的不确定性的度量。值越低,不确定性越小。在本文中,作者提出了基于熵准则的不确定性校正的目标函数:
(8) H = − ∑ n = 1 N ∑ c = S + 1 S + T q c ( x n ) log q c ( x n ) . H = -\sum_{n=1}^{N} \sum_{c=S+1}^{S+T} q_c(x_n) \log{q_c(x_n)}. \tag{8} H=−n=1∑Nc=S+1∑S+Tqc(xn)logqc(xn).(8)
需要实验去看看,这个效果怎么样
优化目标如下:
(9) min L + λ H + γ Ω ( ϕ , ψ ) , \min { L + \lambda{H} + \gamma{\Omega (\phi, \psi)} }, \tag{9} minL+λH+γΩ(ϕ,ψ),(9)
Ω ( ϕ , ψ ) \Omega (\phi, \psi) Ω(ϕ,ψ)是模型复杂度惩罚项。在深度学习中,可以用权值衰减来替代它。
GZSL的精度比ZSL低很多,为什么?
- 源域的精度低,是为什么?
- 目标域精度低,是为什么?模型对源域过拟合。