torch.view打乱tensor顺序使得交叉熵计算出错

pytorch中,torch.view操作不好会打乱tensor的次序,导致计算结果偏差。此处给三个小案例,供参考。
算例1:

entroy=nn.CrossEntropyLoss()
input0 = torch.Tensor([[0.9, 0.2],[0.4, 0.6], [0.3, 0.7]])
input = input0.unsqueeze(0).repeat(2, 1, 1).transpose( 2, 1).contiguous()

print(input)
print(input.shape)
target = torch.tensor([[1,1,0], [1,1,0]])
print(target.shape)

output = entroy(input, target)
print(output)
tensor([[[0.9000, 0.4000, 0.3000],
         [0.2000, 0.6000, 0.7000]],

        [[0.9000, 0.4000, 0.3000],
         [0.2000, 0.6000, 0.7000]]])
torch.Size([2, 2, 3])
torch.Size([2, 3])
tensor(0.8714)

此时logits的配对方式应该如input0所示,即logits中的概率搭配如input0,计算交叉熵的代码可以如下所示:

out0 = -math.log( math.exp(0.2)/(math.exp(0.9)+math.exp(0.2)) )
out1 = -math.log( math.exp(0.6)/(math.exp(0.4)+math.exp(0.6)) )
out2 = -math.log( math.exp(0.3)/(math.exp(0.3)+math.exp(0.7)) )
print( ((out0+out1+out2)/3)*2/2 )

有时会把input与target展平来计算交叉熵损失,如算例2:

entroy=nn.CrossEntropyLoss()
input0=torch.Tensor([[0.9, 0.2],[0.4, 0.6], [0.3, 0.7]])

#input = input.unsqueeze(0).transpose(2, 1).repeat(2, 1, 1).contiguous()
input = input0.unsqueeze(0).repeat(2, 1, 1).contiguous()
print(input.shape)
input = input.view(-1, 2)
print(input)
print(input.shape)
target = torch.tensor([[1,1,0], [1,1,0]])
target = target.view(-1,1)[:, 0]

print(target.shape)
output = entroy(input, target)
print(output)

则结果为:

torch.Size([2, 3, 2])
tensor([[0.9000, 0.2000],
        [0.4000, 0.6000],
        [0.3000, 0.7000],
        [0.9000, 0.2000],
        [0.4000, 0.6000],
        [0.3000, 0.7000]])
torch.Size([6, 2])
torch.Size([6])
tensor(0.8714)

可见展平后损失计算结果相同。
如果在展平之前交换了维度,如算例3:

entroy=nn.CrossEntropyLoss()
input0=torch.Tensor([[0.9, 0.2],[0.4, 0.6], [0.3, 0.7]])

input = input.unsqueeze(0).transpose(2, 1).repeat(2, 1, 1).contiguous()
#input = input0.unsqueeze(0).repeat(2, 1, 1).contiguous()
print(input.shape)
input = input.view(-1, 2)
print(input)
print(input.shape)
target = torch.tensor([[1,1,0], [1,1,0]])
target = target.view(-1,1)[:, 0]

print(target.shape)
output = entroy(input, target)
print(output)
torch.Size([2, 2, 3])
tensor([[0.9000, 0.4000],
        [0.3000, 0.2000],
        [0.6000, 0.7000],
        [0.9000, 0.4000],
        [0.3000, 0.2000],
        [0.6000, 0.7000]])
torch.Size([6, 2])
torch.Size([6])
tensor(0.8210)

此时交叉熵损失计算则变得不正确,这时由于logits的配对被打乱了。

结论:在进行torch.view操作时,要注意保留维度与顺序的一致性(如对torch.Size([2, 2, 3])进行torch.view(-1, 2)操作时会打乱tensor的数据。用torch.Size([2, 3, 2])进行torch.view(-1, 2)操作则不会打乱tensor顺序。

注意:如果在次序打乱的情况下计算loss反向传播,可能导致梯度无法下降。所以在训练时如果发现梯度一直不下降,可以考虑是不是在计算损失时次序弄错了。

你可能感兴趣的:(深度学习,pytorch,深度学习,python)