在进行多分类时,很多时候采用one-hot标签进行计算交叉熵损失,而单纯的交叉熵损失时,只考虑到了正确标签的位置的损失,而忽略了错误标签位置的损失。这样导致模型可能会在训练集上拟合的非常好,但由于其错误标签位置的损失没有计算,导致预测的时候,预测错误的概率比较大,也就是常说的过拟合。
标签平滑可以在一定程度上防止过拟合。
Step1: softmax多分类
P i = e z i ∑ i = 1 n e z i P_i = { e^{z_i} \over {\sum_{i=1}^{n} e^{z_i}} } Pi=∑i=1neziezi
其中, p i p_i pi为当前样本属于类别 i i i的概率, z i z_i zi 指当前样本的对应类别 i i i的 l o g i t logit logit, n表示样本的总列别数。
Step2: 交叉熵损失计算公式:
c r o s s L o s s = − 1 M ∑ m = 1 M ∑ i = 1 n y i l o g p i crossLoss = - {1 \over M} {\sum_{m=1}^M {\sum_{i=1}^n}} y_ilog{p_i} crossLoss=−M1m=1∑Mi=1∑nyilogpi
其中, M M M表示样本综述。
实例:
假设一批样本,样本类别的总数n=5, 其中一个样本的one-hot标签为 [ 0 , 0 , 0 , 1 , 0 ] [0,0,0,1,0] [0,0,0,1,0],假设通过模型(如全连接等)的 l o g i t logit logit进行softmax后的概率矩阵 p p p为:
p = [ 0.1 , 0.1 , 0.1 , 0.36 , 0.34 ] p = [0.1,0.1,0.1, 0.36, 0.34] p=[0.1,0.1,0.1,0.36,0.34]
将其带入到上面的公式,即可计算出单个样本的loss为:
l o s s = − ( 0 ∗ l o g 0.1 + 0 ∗ l o g 0.1 + 0 ∗ l o g 0.1 + 1 ∗ l o g 0.36 + 0 ∗ l o g 0.34 ) = − l o g 0.36 = 1.47 loss = -(0*log0.1+0*log0.1+0*log0.1+1*log0.36+0*log0.34) = -log0.36=1.47 loss=−(0∗log0.1+0∗log0.1+0∗log0.1+1∗log0.36+0∗log0.34)=−log0.36=1.47
这种传统计算交叉熵损失只考虑了正确标签位置的损失,而没有考虑错误标签的损失。下面让我们看看带有标签平滑的交叉熵损失是怎样计算的吧。
同样是上面的例子:一批样本,样本类别的总数n=5, 其中一个样本的one-hot标签为 [ 0 , 0 , 0 , 1 , 0 ] [0,0,0,1,0] [0,0,0,1,0],假设通过模型(如全连接等)的 l o g i t logit logit进行softmax后的概率矩阵 p p p为:
p = [ 0.1 , 0.1 , 0.1 , 0.36 , 0.34 ] p = [0.1,0.1,0.1, 0.36, 0.34] p=[0.1,0.1,0.1,0.36,0.34]
设:标签的平滑因子 ϵ = 0.1 \epsilon=0.1 ϵ=0.1,平滑的计算步骤如下:
y 1 = ( 1 − ϵ ) ∗ [ 0 , 0 , 0 , 1 , 0 ] = [ 0 , 0 , 0 , 0.9 , 0 ] y1 = (1-\epsilon)*[0,0,0,1,0] = [0,0,0,0.9,0] y1=(1−ϵ)∗[0,0,0,1,0]=[0,0,0,0.9,0]
y 2 = ϵ ∗ [ 1 , 1 , 1 , 1 , 1 ] = [ 0.1 , 0.1 , 0.1 , 0.1 , 0.1 ] y2 = \epsilon*[1,1,1,1,1] = [0.1,0.1,0.1,0.1,0.1] y2=ϵ∗[1,1,1,1,1]=[0.1,0.1,0.1,0.1,0.1]
y = y 1 + y 2 = [ 0.1 , 0.1 , 0.1 , 1 , 0.1 ] y = y1+y2 = [0.1,0.1,0.1,1,0.1] y=y1+y2=[0.1,0.1,0.1,1,0.1]
y y y即是平滑后的新标签,然后按照传统的交叉熵损失计算步骤即可,如:
l o s s = − y ∗ l o g p = − [ 0.1 , 0.1 , 0.1 , 1 , 0.1 ] ∗ l o g ( [ 0.1 , 0.1 , 0.1 , 0.36 , 0.34 ] ) = 2.63 loss=-y*logp = -[0.1,0.1,0.1,1,0.1]*log([0.1,0.1,0.1,0.36,0.34])=2.63 loss=−y∗logp=−[0.1,0.1,0.1,1,0.1]∗log([0.1,0.1,0.1,0.36,0.34])=2.63
有上面实例可以看出,带有标签平滑的损失要比传统交叉熵损失要更大。换言之,带有标签平滑的损失要想下降到传统交叉熵损失的程度,就要学习的更好,迫使模型往正确分类的方向走。
只要用到的是交叉熵损失(cross loss),都可以采取标签平滑处理。