# -*- coding: utf-8 -*-
"""
Created on 2022/10/8 20:04
@author: Janben
"""
import torch
import torch.nn.functional as Function
import numpy as np
def labalsmoothing(x, target, GetTrueDist=False):
'''
:param x: 直接从模型输出的预测结果,没有进行softmax也没有log,形状为(batchsize,class_num)
:param target: 真实标签,形状为(batchsize,)
:return:
'''
smoothing = 0.1
class_num = 5
true_dist = x.data.clone()
true_dist.fill_(smoothing / (class_num - 1))
confidence = 1-smoothing
true_dist.scatter_(1, target.data.unsqueeze(1), confidence) # 此行代码实现了标签平滑
print(true_dist)
# 计算交叉熵,以下是计算交叉熵的公式,当smoothing=0时,结果与库函数nn.CrossEntropyLoss()(x,target)一样,但库函数的输入要求是原始的预测值,不要进行softmax和log
logprobs = Function.log_softmax(x, dim=1) # softmax+log
print(logprobs)
mean_loss = -torch.sum(true_dist * logprobs) / x.size(0) # 除以样本数量
if GetTrueDist == True:
return mean_loss, true_dist
else:
return mean_loss
x = [[-1.69077412, 0.94173985, 0.80724937, 0.69399691, 0.00955748],
[-0.18597058, 0.02671462, 0.55668541, -0.17869505, -0.43943024],
[-0.01852173, 1.30501542, 1.12420786, -0.40676145, -0.19358048],
[ 1.71318642, -0.58123289, 2.37872253, -0.70580169, -0.75736412],
[-0.61055401, 0.49647492, 1.55212542, 0.85372002, 0.09467156],
[-0.04895827, -0.13194447, 1.86764062, -0.35986444, -0.46494589],
[-0.11814479, 0.25389836, 0.7644965 , -1.52282513, -0.95201391],
[ 0.34373188, -1.29832594, -0.46132988, 1.73043535, -0.69572854]]#np.random.randn(8,5)
tar = [4, 3, 4, 1, 1, 0, 4, 2]#np.random.randint(0,5,(8,))
a = labalsmoothing(torch.tensor(x),torch.tensor(tar))
print(a)
标签平滑方法实质就是将标签转化为正类置信度的分布,相当于对标签进行``软化'',使标签不再``非0即1''。将标签平滑处理之后采用交叉熵计算训练损失,也可以用KL散度torch.nn.KLDivLoss计算损失。计算损失就是计算训练的预测值的分布与标签平滑后的分布的差异性。
我们将平滑值$smoothing$设置为0.1,表明在进行标签平滑时按正类置信度为90\%(1-0.1=0.9)将原始标签进行变换。
假设有5个类别,batchsize=8,标签序列为[4, 3, 4, 1, 1, 0, 4, 2],按0.1进行标签平滑后的结果为
tensor([[0.0250, 0.0250, 0.0250, 0.0250, 0.9000],
[0.0250, 0.0250, 0.0250, 0.9000, 0.0250],
[0.0250, 0.0250, 0.0250, 0.0250, 0.9000],
[0.0250, 0.9000, 0.0250, 0.0250, 0.0250],
[0.0250, 0.9000, 0.0250, 0.0250, 0.0250],
[0.9000, 0.0250, 0.0250, 0.0250, 0.0250],
[0.0250, 0.0250, 0.0250, 0.0250, 0.9000],
[0.0250, 0.0250, 0.9000, 0.0250, 0.0250]])