Pytorch使用permute后再reshape会修改张量的值,维度虽然对齐但仍需小心

楼主最近阅读一篇论文的代码,在修改的时候发现loss下降正常但评测指标怎么都上不去,经过两天的排查找到原因。
这篇论文的代码在模型预测之后使用了permute和reshape,导致我模型的预测和预想的完全不一样,示例如下:

predict=model(train).permute(0,1,3,2).reshape(-1,n,66)

虽然表面上看用不用permute最后经过reshape得到的维度都是一样的,也能对其,后续工作也能展开,但问题就在这里,使用了permute后再reshape里面的值会完全不一样,导致虽然没有bug但后续工作受阻

举个例子:
先创建一个[2,3,3]的张量:

Pytorch使用permute后再reshape会修改张量的值,维度虽然对齐但仍需小心_第1张图片
然后一个使用permute一个不用,最后都使用reshape得到一样的维度
Pytorch使用permute后再reshape会修改张量的值,维度虽然对齐但仍需小心_第2张图片
可以看到,虽然a和b都是3✖6的张量,但内容却完全不一样了,所以有时候维度虽然对齐了也是有问题的

导致这种情况的原因是permute会直接修改矩阵的维度,permute(1,0,2)相当于交换0,1两个维度,即将[2,3,3]改变为[3,2,3],原来是2块,每块里面3行3列,现在是3块,每块里面2行3列,之前两块里面的每两行组成了新的一块,而reshape则是将一块一块拆开后重新排列组合,最终导致了值的不同:
Pytorch使用permute后再reshape会修改张量的值,维度虽然对齐但仍需小心_第3张图片
所以如果在模型预测之后你的代码依然有permute函数和reshape函数,需要注意后续的评测标准,因为模型预测的值可能已经被修改了,最后你可能会因为精确度太差而怀疑是不是模型写错了

你可能感兴趣的:(坑)