pytorch中,torch.view操作不好会打乱tensor的次序,导致计算结果偏差。此处给三个小案例,供参考。
算例1:
entroy=nn.CrossEntropyLoss()
input0 = torch.Tensor([[0.9, 0.2],[0.4, 0.6], [0.3, 0.7]])
input = input0.unsqueeze(0).repeat(2, 1, 1).transpose( 2, 1).contiguous()
print(input)
print(input.shape)
target = torch.tensor([[1,1,0], [1,1,0]])
print(target.shape)
output = entroy(input, target)
print(output)
tensor([[[0.9000, 0.4000, 0.3000],
[0.2000, 0.6000, 0.7000]],
[[0.9000, 0.4000, 0.3000],
[0.2000, 0.6000, 0.7000]]])
torch.Size([2, 2, 3])
torch.Size([2, 3])
tensor(0.8714)
此时logits的配对方式应该如input0所示,即logits中的概率搭配如input0,计算交叉熵的代码可以如下所示:
out0 = -math.log( math.exp(0.2)/(math.exp(0.9)+math.exp(0.2)) )
out1 = -math.log( math.exp(0.6)/(math.exp(0.4)+math.exp(0.6)) )
out2 = -math.log( math.exp(0.3)/(math.exp(0.3)+math.exp(0.7)) )
print( ((out0+out1+out2)/3)*2/2 )
有时会把input与target展平来计算交叉熵损失,如算例2:
entroy=nn.CrossEntropyLoss()
input0=torch.Tensor([[0.9, 0.2],[0.4, 0.6], [0.3, 0.7]])
#input = input.unsqueeze(0).transpose(2, 1).repeat(2, 1, 1).contiguous()
input = input0.unsqueeze(0).repeat(2, 1, 1).contiguous()
print(input.shape)
input = input.view(-1, 2)
print(input)
print(input.shape)
target = torch.tensor([[1,1,0], [1,1,0]])
target = target.view(-1,1)[:, 0]
print(target.shape)
output = entroy(input, target)
print(output)
则结果为:
torch.Size([2, 3, 2])
tensor([[0.9000, 0.2000],
[0.4000, 0.6000],
[0.3000, 0.7000],
[0.9000, 0.2000],
[0.4000, 0.6000],
[0.3000, 0.7000]])
torch.Size([6, 2])
torch.Size([6])
tensor(0.8714)
可见展平后损失计算结果相同。
如果在展平之前交换了维度,如算例3:
entroy=nn.CrossEntropyLoss()
input0=torch.Tensor([[0.9, 0.2],[0.4, 0.6], [0.3, 0.7]])
input = input.unsqueeze(0).transpose(2, 1).repeat(2, 1, 1).contiguous()
#input = input0.unsqueeze(0).repeat(2, 1, 1).contiguous()
print(input.shape)
input = input.view(-1, 2)
print(input)
print(input.shape)
target = torch.tensor([[1,1,0], [1,1,0]])
target = target.view(-1,1)[:, 0]
print(target.shape)
output = entroy(input, target)
print(output)
torch.Size([2, 2, 3])
tensor([[0.9000, 0.4000],
[0.3000, 0.2000],
[0.6000, 0.7000],
[0.9000, 0.4000],
[0.3000, 0.2000],
[0.6000, 0.7000]])
torch.Size([6, 2])
torch.Size([6])
tensor(0.8210)
此时交叉熵损失计算则变得不正确,这时由于logits的配对被打乱了。
结论:在进行torch.view操作时,要注意保留维度与顺序的一致性(如对torch.Size([2, 2, 3])进行torch.view(-1, 2)操作时会打乱tensor的数据。用torch.Size([2, 3, 2])进行torch.view(-1, 2)操作则不会打乱tensor顺序。
注意:如果在次序打乱的情况下计算loss反向传播,可能导致梯度无法下降。所以在训练时如果发现梯度一直不下降,可以考虑是不是在计算损失时次序弄错了。