【PyTorch】TypeError: stack(): argument 'tensors' (position 1) must be tuple

0x00 前言

近期有个版本适配的任务,说白了就是把 PyTorch 0.3.0 的代码更新适配 PyTorch 1.0.2
PyTorch 的向上兼容性在此时就可以体现出来了,令人欣慰的是直接升级版本后并没有太多报错,
其中一个比较突出的问题就是 torch.stack()torch.cat() 的变化。

0x01 错误演示 @ torch.stack()

>>> x = torch.randn(3, 2, 3)
>>> torch.stack(x, 0)

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-16-97ff2667b2e9> in <module>()
      1 import torch
      2 x = torch.randn(3, 2, 3)
----> 3 torch.stack(x, 0)

TypeError: stack(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor

现在 torch.stack() 的输入必须是一个 python-list ,里面每个元素都是一个 tensor,
不可以和原先 0.3.0 的时候一样,输入一个矩阵直接按某一个维度进行 stack 操作。

0x02 错误演示 @ torch.cat()

>>> x = torch.randn(3)
>>> x
tensor([ 0.5710,  0.4324, -0.5154])

>>> x.sum()
tensor(0.4879)

>>> torch.cat([x.sum(), x.sum()], 0)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-22-38992e81433c> in <module>()
----> 1 torch.cat([x.sum(), x.sum()], 0)

RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated

torch.cat() 现在不允许将 0 维 tensor 拼在一起了,如果有类似需求可以考虑直接新建一个 Tensor。

0x03 正确示范

>>> x = torch.randn(2, 3)
>>> x

 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x3]

>>> torch.cat([x, x, x], 0)

 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 6x3]

>>> torch.cat((x, x, x), 1)

 0.5983 -0.0341  2.4918  0.5983 -0.0341  2.4918  0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735  1.5981 -0.5265 -0.8735  1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x9]

0x04 代码更新 DIFF实例

刚学pytorch的时候写的代码,现在看起来……
好想重构啊,算了任务太多来不及以后再说吧

torch.stack()

# Original
t1 = torch.stack(t.repeat(seg_len, 1, 1, 1), 2).view(-1, hid_size)
t2 = torch.stack(t.repeat(seg_len, 1, 1, 1), 1).view(-1, hid_size)

# New
rep_t = [t] * seg_len
t1 = torch.stack(rep_t, 2).view(-1, hid_size)
t2 = torch.stack(rep_t, 1).view(-1, hid_size)

torch.cat()

# Original
macro_count = torch.stack(
            [torch.cat([positive_true.sum(), positive_false.sum()], 0),
             torch.cat([negative_true.sum(), negative_false.sum()], 0)], 1)
# New
macro_count = torch.LongTensor([
            [positive_true.sum(), positive_false.sum()],
            [negative_true.sum(), negative_false.sum()]])

你可能感兴趣的:(DIY)