[Notes] Reimplementing PyTorch's nn.CrossEntropyLoss()
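nn.CrossEntropyLoss takes raw logits (no softmax applied) and combines log-softmax with negative log-likelihood. For one sample with logit vector x and target class l, the per-sample loss is

\ell(x, l) = -x_l + \log \sum_j \exp(x_j)

and the default reduction='mean' averages this over the batch. The loop below computes exactly that.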

import torch
import torch.nn as nn

def myCrossEntropyLoss(output, label):
    # number of samples in the batch
    count = label.size(0)

    loss = 0.0
    for x, l in zip(output, label):
        # per-sample loss: -logit of the true class + log-sum-exp of all logits
        loss += -1 * x[l] + torch.log(torch.exp(x).sum())

    return loss / count
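The loop is clear, but torch.exp(x).sum() can overflow for large logits. A minimal vectorized sketch (my own variant, not part of the original snippet) that computes the same quantity stably with torch.logsumexp:

def myCrossEntropyLossVectorized(output, label):
    # stable log-sum-exp over the class dimension
    lse = torch.logsumexp(output, dim=1)
    # logit of the true class for each sample in the batch
    true_logit = output[torch.arange(output.size(0)), label]
    # mean over the batch matches the default reduction='mean'
    return (lse - true_logit).mean()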

output = torch.randn(10, 5, requires_grad=True)  # pretend final-layer logits: batch of 10, 5 classes
label = torch.empty(10, dtype=torch.long).random_(5)  # a random target class in [0, 4] for each sample
print(output.shape, label.shape)

print()

loss = myCrossEntropyLoss(output, label)
print('my loss = {:.5f}'.format(loss))

nnCrossEntropyLoss = nn.CrossEntropyLoss()
nnCrossEntropyLossWithIgnore = nn.CrossEntropyLoss(ignore_index=0)
loss = nnCrossEntropyLoss(output, label)
loss_with_ignore = nnCrossEntropyLossWithIgnore(output, label)
print('torch loss = {:.5f}'.format(loss.item()))
print('torch loss with ignore = {:.5f}'.format(loss_with_ignore.item()))
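Note that the ignore_index=0 result will generally differ from the plain loss: samples whose target equals 0 are excluded from both the sum and the mean. A sketch of mirroring that behavior by hand (the function name and masking are my own, assuming reduction='mean'):

def myCrossEntropyLossWithIgnore(output, label, ignore_index=0):
    # drop samples whose target is the ignored class
    keep = label != ignore_index
    output, label = output[keep], label[keep]
    lse = torch.logsumexp(output, dim=1)
    true_logit = output[torch.arange(output.size(0)), label]
    # average over the remaining samples only
    return (lse - true_logit).mean()

print('my loss with ignore = {:.5f}'.format(myCrossEntropyLossWithIgnore(output, label).item()))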
