多分类问题的soft cross entropy 损失函数

在做多分类问题的时候,分类结果的损失函数经常使用交叉熵损失函数,对于预测结果,先经过softmax,然后经过log,然后再根据One hot向量只取得其中的一个值作为损失函数的吸收值,比如,logsoftmax后的值为[-0.6, -0.12 , -0.33, -0.334, -0.783],假设one hot label 为[ 0,0,0,0,1 ],则损失函数的值为 Loss = 0.783,,也就是说,只有一个值纳入了计算,我就在想,可不可以将所有的值都纳入计算呢,如果这样的话,就得将label转为soft label ,为 [0.2, 0.2 , 0.2, 0.2, 1],将0 Label 的地方设置为 1/(label.shape[0]),再进行计算损失,则这个对应的损失函数的pytorch实现如下所示:

class softcrossentropy(nn.Module):
    def __init__(self):
        super( softcrossentropy, self ).__init__()
        self.cel = nn.CrossEntropyLoss()
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, labels):
        loss_tra_cel = self.cel( inputs, labels )

        ls = (-1) * self.logsoftmax(inputs)
        ls_sum = torch.sum(ls, dim=1) / inputs.shape[1]

        loss_soft_cel = torch.sum( ls_sum ) / inputs.shape[0]

        loss = loss_tra_cel + loss_soft_cel
        return loss

     分别利用上面的softccrossentroy和正常的crossentropy作为分类损失函数损失函数,对market1501进行行人重试别的训练,记录每个mini batch的train loss 、train accuracy、val loss 、val accuracy,在训练过程中将这些数据都保存起来,后期进行比对分析。训练的数据在这了:链接: https://pan.baidu.com/s/18jYFFd9LaPrbd72OdHjdpw  密码: hg9r

利用训练保留的数据进行后期分析比对,文件目录如下所示:

多分类问题的soft cross entropy 损失函数_第1张图片

mat.py的代码如下:

import scipy.io
import matplotlib.pyplot as plt
# train_acc
# train_loss
# val_loss
# val_acc

soft_res = scipy.io.loadmat('soft_mini_batch_data.mat')
soft_train_acc = soft_res['train_acc']

tra_res = scipy.io.loadmat('tra_mini_batch_data.mat')
tra_train_acc = tra_res['train_acc']

soft_acc_list = soft_train_acc[0].tolist()
train_acc_list = tra_train_acc[0].tolist()

data_length = int(len( soft_acc_list )/10)
x_label = [ i for i in range(data_length) ]

fig, ax = plt.subplots()
ax.plot( x_label, soft_acc_list[0:data_length],label = 'soft' )
ax.plot( x_label, train_acc_list[0:data_length], label = 'tra' )
ax.set_xlabel('mini batch num')
ax.set_ylabel('accuracy')
ax.set_title('god bless')
ax.legend()
plt.show()

      这里比较了在不同的criterion下使用同样的方法,对同样的数据进行训练的过程中的几个数值的变化情况。这个代码只比较了每个Mini batch的准确率情况。对比图如下所示:

多分类问题的soft cross entropy 损失函数_第2张图片

你可能感兴趣的:(pytorch,行人重识别,cross,entropy)