近期有个版本适配的任务,说白了就是把 PyTorch 0.3.0
的代码更新适配 PyTorch 1.0.2
,
PyTorch 的向上兼容性在此时就可以体现出来了,令人欣慰的是直接升级版本后并没有太多报错,
其中一个比较突出的问题就是 torch.stack()
和 torch.cat()
的变化。
>>> x = torch.randn(3, 2, 3)
>>> torch.stack(x, 0)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-16-97ff2667b2e9> in <module>()
1 import torch
2 x = torch.randn(3, 2, 3)
----> 3 torch.stack(x, 0)
TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor
现在 torch.stack()
的输入必须是一个 python-list ,里面每个元素都是一个 tensor,
不可以和原先 0.3.0 的时候一样,输入一个矩阵直接按某一个维度进行 stack 操作。
>>> x = torch.randn(3)
>>> x
tensor([ 0.5710, 0.4324, -0.5154])
>>> x.sum()
tensor(0.4879)
>>> torch.cat([x.sum(), x.sum()], 0)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-22-38992e81433c> in <module>()
----> 1 torch.cat([x.sum(), x.sum()], 0)
RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated
torch.cat()
现在不允许将 0 维 tensor 拼在一起了,如果有类似需求可以考虑直接新建一个 Tensor。
>>> x = torch.randn(2, 3)
>>> x
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x3]
>>> torch.cat([x, x, x], 0)
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 6x3]
>>> torch.cat((x, x, x), 1)
0.5983 -0.0341 2.4918 0.5983 -0.0341 2.4918 0.5983 -0.0341 2.4918
1.5981 -0.5265 -0.8735 1.5981 -0.5265 -0.8735 1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x9]
刚学pytorch的时候写的代码,现在看起来……
好想重构啊,算了任务太多来不及以后再说吧
# Original
t1 = torch.stack(t.repeat(seg_len, 1, 1, 1), 2).view(-1, hid_size)
t2 = torch.stack(t.repeat(seg_len, 1, 1, 1), 1).view(-1, hid_size)
# New
rep_t = [t] * seg_len
t1 = torch.stack(rep_t, 2).view(-1, hid_size)
t2 = torch.stack(rep_t, 1).view(-1, hid_size)
# Original
macro_count = torch.stack(
[torch.cat([positive_true.sum(), positive_false.sum()], 0),
torch.cat([negative_true.sum(), negative_false.sum()], 0)], 1)
# New
macro_count = torch.LongTensor([
[positive_true.sum(), positive_false.sum()],
[negative_true.sum(), negative_false.sum()]])