深入推导理解sample softmax loss

文章目录

    • 何为logit
    • loss中的logit修正
    • NCE loss
    • sample softmax loss
    • sample softmax推导过程
    • 理解logit修正
    • 部分参考资料

查看TensorFlow关于nce loss和sample softmax loss时,发现都会对logit进行修正。为了搞清楚为什么需要对logit修正,以及为什么可以这样修正,参考许多资料后,现总结成文,以飨诸君。

何为logit

Mathematically, the logit is the inverse(反函数) of the standard logistic function σ ( x ) = 1 1 + e − x \sigma(x)=\frac{1}{1+e^{-x}} σ(x)=1+ex1,so the logit is define as:

l o g i t ( p ) = σ − 1 ( p ) = l n ( p 1 − p ) f o r p ∈ ( 0 , 1 ) . logit(p) = \sigma^{-1}(p)=ln(\frac{p}{1-p})\quad for \quad p \in (0,1). logit(p)=σ1(p)=ln(1pp)forp(0,1).

因此,logit 也称为对数赔率(log-odds),因为它等于赔率    ( p 1 − p )    \;(\frac{p}{1-p})\; (1pp)的对数,其中 p p p是概率。

推导如下:

l e t    y = l o g i t ( x ) = l n x 1 − x e y = x 1 − x 1 + e y = 1 − x 1 − x + x 1 − x 1 + e y = 1 1 − x x = 1 − 1 1 + e y x = e y 1 + e y x = 1 1 + e − y \begin{align*} let \; y &= logit(x)=ln\frac{x}{1-x} \\ e^y &= \frac{x}{1-x} \\ 1+e^y &= \frac{1-x}{1-x}+\frac{x}{1-x}\\ 1+e^y &=\frac{1}{1-x} \\ x &= 1-\frac{1}{1+e^y} \\ x &= \frac{e^y}{1+e^y} \\ x &= \frac{1}{1+e^{-y}} \end{align*} letyey1+ey1+eyxxx=logit(x)=ln1xx=1xx=1x1x+1xx=1x1=11+ey1=1+eyey=1+ey1

即,logit作为sigmoid函数的输入。

loss中的logit修正

需要对logit修正。修正量与负采样得到的相同y的概率 Q ( y ∣ x ) Q(y|x) Q(yx)有关。

修正后的logit如下形式,表示x与y的匹配度:

G ( x , y ) = F ( x , y ) − l o g    Q ( y ∣ x ) G(x,y) = F(x,y) - log\;Q(y|x) G(x,y)=F(x,y)logQ(yx)

NCE loss

NCE-based methods are used for estimating the parameters of a statistical distribution by differentiating between “real data” and “noise”.

对于NCE而言,是将多分类问题转化为了若干个二分类问题。因此,修正后的logit,作为sigmoid函数的输入,以进一步计算binary交叉熵损失。

L N C E = − [ ∑ y i ∈ T i l o g ( σ ( G ( x i , y i ) ) ) + ∑ y i ′ ∈ S i l o g ( 1 − σ ( G ( x i , y i ′ ) ) ) ] = ∑ y i ∈ T i l o g ( 1 + e x p ( − G ( x i , y i ) ) ) + ∑ y i ′ ∈ S i l o g ( 1 + e x p ( G ( x i , y i ′ ) ) ) \begin{align*} L_{NCE} &= -[\sum_{y_i \in T_i}log(\sigma(G(x_i,y_i))) + \sum_{y_i' \in S_i}log(1- \sigma(G(x_i,y_i')))] \\ &= \sum_{y_i \in T_i}log(1+exp(-G(x_i,y_i))) + \sum_{y_i' \in S_i}log(1 + exp(G(x_i,y_i'))) \end{align*} LNCE=[yiTilog(σ(G(xi,yi)))+yiSilog(1σ(G(xi,yi)))]=yiTilog(1+exp(G(xi,yi)))+yiSilog(1+exp(G(xi,yi)))

sample softmax loss

与NCE一样的修正公式,也就是说模型得到 F ( x , y ) F(x,y) F(x,y)(比如user embedding和item embedding的点积)之后,再根据负采样到 y y y的概率 Q ( y ∣ x ) Q(y|x) Q(yx)进行修正,修正后的数值才喂入softmax计算loss:

L S a m p l e S o f t m a x = − l o g    e x p ( G ( x i , y i ) ) ∑ y i ′ ∈ C i e x p ( G ( x i , y i ′ ) ) = − G ( x i , y i ) + l o g    ∑ y i ′ ∈ C i e x p ( G ( x i , y i ′ ) ) \begin{align*} L_{SampleSoftmax} &= -log\;\frac{exp(G(x_i,y_i))}{\sum_{y_i' \in C_i} exp(G(x_i,y_i'))} \\ &= -G(x_i,y_i) + log\;\sum_{y_i' \in C_i} exp(G(x_i,y_i')) \end{align*} LSampleSoftmax=logyiCiexp(G(xi,yi))exp(G(xi,yi))=G(xi,yi)+logyiCiexp(G(xi,yi))

注意:由于预测的时候不需要计算loss,因此不需要修正logit。


sample softmax推导过程

如何理解sample softmax中的 G ( x , y ) G(x,y) G(x,y)

以u2i场景描述该问题,给定一个用户 x i x_i xi,其点击的物料是 t i t_i ti,再给他按照 Q ( y , x ) Q(y,x) Q(y,x)采样一批负样本 S i S_i Si。原始的softmax问题是,在整个物料库 I I I中那个item是 x i x_i xi点击的。通过sample softmax,问题演变为在 x i x_i xi的候选集 C i = { t i } ∪ S i C_i=\{t_i\}\cup S_i Ci={ti}Si中,正确挑选出 t i t_i ti的概率是多少,即建模 P ( y = t i ∣ x i , C i ) P(y=t_i|x_i,C_i) P(y=tixi,Ci)

假设我们聚焦于第i个样本,以下公式省略下标i,那么根据条件概率公式展开,有:

P ( y ∣ x , C ) = P ( y , x , C ) P ( x , C ) = P ( C ∣ x , y ) P ( x , y ) P ( C ∣ x ) P ( x ) = P ( C ∣ x , y ) P ( y ∣ x ) P ( x ) P ( C ∣ x ) P ( x ) = P ( C ∣ x , y ) P ( y ∣ x ) P ( C ∣ x ) \begin{align*} P(y|x,C) &= \frac{P(y,x,C)}{P(x,C)} \\ &= \frac{P(C|x,y)P(x,y)}{P(C|x)P(x)} \\ &= \frac{P(C|x,y)P(y|x)P(x)}{P(C|x)P(x)} \\ &= \frac{P(C|x,y)P(y|x)}{P(C|x)} \end{align*} P(yx,C)=P(x,C)P(y,x,C)=P(Cx)P(x)P(Cx,y)P(x,y)=P(Cx)P(x)P(Cx,y)P(yx)P(x)=P(Cx)P(Cx,y)P(yx)

其中, P ( C ∣ x , y ) P(C|x,y) P(Cx,y)表示在用户 x x x和某一个物料 y y y给定的情况下,构成整个物料候选集 C C C的概率。由于 C = { t } ∪ S C=\{t\}\cup S C={t}S,因此它等价于 C C C中每个物料被采样到的概率,与 I − C I-C IC中每个物料未被采样到的概率的乘积。即:
P ( C ∣ x , y ) = ∏ y ′ ∈ C − y Q ( y ′ ∣ x ) × ∏ y ′ ∈ I − C 1 − Q ( y ′ ∣ x ) \begin{align*} P(C|x,y) = \prod_{y'\in C-y}Q(y'|x) \times \prod_{y'\in I-C} 1 - Q(y'|x) \end{align*} P(Cx,y)=yCyQ(yx)×yIC1Q(yx)

由此可得:

P ( y ∣ x , C ) = P ( C ∣ x , y ) P ( y ∣ x ) P ( C ∣ x ) = P ( y ∣ x ) × ∏ y ′ ∈ C − y Q ( y ′ ∣ x ) × ∏ y ′ ∈ I − C 1 − Q ( y ′ ∣ x ) P ( C ∣ x ) = P ( y ∣ x ) Q ( y , x ) × ∏ y ′ ∈ C Q ( y ′ ∣ x ) × ∏ y ′ ∈ I − C 1 − Q ( y ′ ∣ x ) P ( C ∣ x ) = P ( y ∣ x ) Q ( y , x ) × K ( x , C ) \begin{align*} P(y|x,C) &= \frac{P(C|x,y)P(y|x)}{P(C|x)} \\ &= \frac{P(y|x) \times \prod_{y'\in C-y}Q(y'|x) \times \prod_{y'\in I-C} 1 - Q(y'|x)}{P(C|x)} \\ &= \frac{P(y|x)}{Q(y,x)} \times \frac{\prod_{y'\in C}Q(y'|x) \times \prod_{y'\in I-C} 1 - Q(y'|x)}{P(C|x)} \\ &= \frac{P(y|x)}{Q(y,x)} \times K(x,C) \end{align*} P(yx,C)=P(Cx)P(Cx,y)P(yx)=P(Cx)P(yx)×yCyQ(yx)×yIC1Q(yx)=Q(y,x)P(yx)×P(Cx)yCQ(yx)×yIC1Q(yx)=Q(y,x)P(yx)×K(x,C)

其中, ∏ y ′ ∈ C Q ( y ′ ∣ x ) × ∏ y ′ ∈ I − C 1 − Q ( y ′ ∣ x ) P ( C ∣ x ) \frac{\prod_{y'\in C}Q(y'|x) \times \prod_{y'\in I-C} 1 - Q(y'|x)}{P(C|x)} P(Cx)yCQ(yx)×yIC1Q(yx)是与当前预测的 y y y无关,因此等价于一个与 x x x C C C有关的常数 K ( x , C ) K(x,C) K(x,C)

进一步,两边取对数可得:

P ( y ∣ x , C ) = P ( y ∣ x ) Q ( y , x ) × K ( x , C ) l o g    P ( y ∣ x , C ) = l o g    P ( y ∣ x ) − l o g    Q ( y ∣ x ) + K ′ ( x , C ) \begin{align*} P(y|x,C) &= \frac{P(y|x)}{Q(y,x)} \times K(x,C) \\ log\;P(y|x,C) &= log\;P(y|x) - log\;Q(y|x) + K'(x,C) \end{align*} P(yx,C)logP(yx,C)=Q(y,x)P(yx)×K(x,C)=logP(yx)logQ(yx)+K(x,C)

基于此,loss表示为:

L S a m p l e S o f t m a x = − l o g    e x p ( G ( x , y ) ) ∑ y ′ ∈ C i e x p ( G ( x , y ′ ) ) = − G ( x , y ) + l o g    ∑ y ′ ∈ C i e x p ( G ( x , y ′ ) ) \begin{align*} L_{SampleSoftmax} &= -log\;\frac{exp(G(x,y))}{\sum_{y' \in C_i} exp(G(x,y'))} \\ &= -G(x,y) + log\;\sum_{y' \in C_i} exp(G(x,y')) \end{align*} LSampleSoftmax=logyCiexp(G(x,y))exp(G(x,y))=G(x,y)+logyCiexp(G(x,y))

其中, l o g    P ( y ∣ x ) log\;P(y|x) logP(yx)可以理解为模型输出的 x x x y y y的匹配度 F ( x , y ) F(x,y) F(x,y)。与候选物料 y y y无关的常数 K K K不影响softmax的结果,因此最后需要喂入softmax的 x , y x,y x,y匹配度需写成 G ( x , y ) G(x,y) G(x,y),即:

G ( x , y ) = F ( x , y ) − l o g    Q ( y ∣ x ) \begin{align*} G(x,y) = F(x,y) - log\;Q(y|x) \end{align*} G(x,y)=F(x,y)logQ(yx)


理解logit修正

为了方便理解推导采样后的logit的修正公式,可以先考虑原始类别全集 I I I上logit与概率的关系。即:
P ( y ∣ x ) = e x p ( F ( x , y ) ) ∑ y ′ ∈ I e x p ( F ( x , y ′ ) ) \begin{align*} P(y|x) &= \frac{exp(F(x,y))}{\sum_{y'\in I} exp(F(x, y'))} \\ \end{align*} P(yx)=yIexp(F(x,y))exp(F(x,y))

其中, F ( x , y ) F(x,y) F(x,y)为模型输出的 x x x y y y的匹配度(比如user embedding和item embedding的点积),即给定输入x,输出类别为y的logits。上式两边同时取对数,有:

l o g    P ( y ∣ x ) = l o g    e x p ( F ( x , y ) ) ∑ y ′ ∈ I e x p ( F ( x , y ′ ) ) l o g    P ( y ∣ x ) = F ( x , y ) − l o g ∑ y ′ ∈ I e x p ( F ( x , y ′ ) ) l o g    P ( y ∣ x ) = F ( x , y ) − K ( x ) \begin{align*} log\;P(y|x) &= log\;\frac{exp(F(x,y))}{\sum_{y'\in I} exp(F(x, y'))} \\ log\;P(y|x) &= F(x,y) - log \sum_{y'\in I} exp(F(x, y')) \\ log\;P(y|x) &= F(x,y) - K(x) \\ \end{align*} logP(yx)logP(yx)logP(yx)=logyIexp(F(x,y))exp(F(x,y))=F(x,y)logyIexp(F(x,y))=F(x,y)K(x)

采样后类别子集 C i C_i Ci上的logit与原始类别全集 I I I上logit的关系?基于以上的推导,概率与logit之间存在以下关系:

原始类别全集 I I I上logit与概率的关系:
l o g    P ( y ∣ x ) = F ( x , y ) − K 1 ( x ) \begin{align} log\;P(y|x) &= F(x,y) - K_1(x) \end{align} logP(yx)=F(x,y)K1(x)

同理可得,采样后类别子集 C i C_i Ci上logit与概率的关系:
l o g    P ( y ∣ x , C ) = F ( x , y ∣ C ) − K 2 ( x ) \begin{align} log\;P(y|x,C) &= F(x,y|C) - K_2(x) \end{align} logP(yx,C)=F(x,yC)K2(x)

又根据条件概率计算所得的采样后类别子集 C i C_i Ci与全集 I I I上的概率分布关系,

l o g    P ( y ∣ x , C ) = l o g    P ( y ∣ x ) − l o g    Q ( y ∣ x ) + K ′ ( x , C ) \begin{align} log\;P(y|x,C) &= log\;P(y|x) - log\;Q(y|x) + K'(x,C) \end{align} logP(yx,C)=logP(yx)logQ(yx)+K(x,C)

由(1)、(2)、(3)式可得:

l o g    P ( y ∣ x , C ) = l o g    P ( y ∣ x ) − l o g    Q ( y ∣ x ) + K ′ ( x , C ) l o g    P ( y ∣ x , C ) = F ( x , y ) − K 1 ( x ) − l o g    Q ( y ∣ x ) + K ′ ( x , C ) F ( x , y ∣ C ) − K 2 ( x ) = F ( x , y ) − K 1 ( x ) − l o g    Q ( y ∣ x ) + K ′ ( x , C ) F ( x , y ∣ C ) = F ( x , y ) − l o g    Q ( y ∣ x ) + K ′ ( x , C ) − K 1 ( x ) + K 2 ( x ) F ( x , y ∣ C ) = F ( x , y ) − l o g    Q ( y ∣ x ) + C o n s t \begin{align*} log\;P(y|x,C) &= log\;P(y|x) - log\;Q(y|x) + K'(x,C) \\ log\;P(y|x,C) &= F(x,y) - K_1(x) - log\;Q(y|x) + K'(x,C) \\ F(x,y|C) - K_2(x) &= F(x,y) - K_1(x) - log\;Q(y|x) + K'(x,C) \\ F(x,y|C) &= F(x,y) - log\;Q(y|x) + K'(x,C) - K_1(x) + K_2(x) \\ F(x,y|C) &= F(x,y) - log\;Q(y|x) + Const \end{align*} logP(yx,C)logP(yx,C)F(x,yC)K2(x)F(x,yC)F(x,yC)=logP(yx)logQ(yx)+K(x,C)=F(x,y)K1(x)logQ(yx)+K(x,C)=F(x,y)K1(x)logQ(yx)+K(x,C)=F(x,y)logQ(yx)+K(x,C)K1(x)+K2(x)=F(x,y)logQ(yx)+Const

由此推导出的公式便是我们进行采样后的logit,即 F ( x , y ∣ C ) F(x,y|C) F(x,yC)与原始logit 即 F ( x , y ) F(x,y) F(x,y)之间的关系。
这里的修正量便是 l o g    Q ( y ∣ x ) log\;Q(y|x) logQ(yx)。具体使用方式如下:

  1. 通过 Q ( y ∣ x ) Q(y|x) Q(yx)对类别采样,得到一个类别子集 C i C_i Ci
  2. 模型对采样类别子集 C i C_i Ci中的类别分别计算logit,即 F ( x , y ) F(x,y) F(x,y)
  3. 对于计算出的 F ( x , y ) F(x,y) F(x,y),减去修正量 Q ( y ∣ x ) Q(y|x) Q(yx),得到采样后子集的logit,即 F ( x , y ∣ C ) F(x,y|C) F(x,yC)
  4. 使用 F ( x , y ∣ C ) F(x,y|C) F(x,yC)作为softmax的输入,计算不同类别的概率分布。同时使用 F ( x , y ∣ C ) F(x,y|C) F(x,yC)计算loss进行梯度下降。

部分参考资料

https://zhuanlan.zhihu.com/p/528862933

https://zhuanlan.zhihu.com/p/143830417

https://zhuanlan.zhihu.com/p/539913423

你可能感兴趣的:(深度学习,机器学习,深度学习,人工智能)