torch.cat 简略用法

1 torch.cat用途

用于将张量concatnate拼接,可按行、列拼接

2 按行列拼接(注意张量维度)

C = torch.cat( (A,B),0 ) #按维数0拼接(接着行竖着拼)
C = torch.cat( (A,B),1 ) #按维数1拼接(接着列横着拼)

示例

import torch
A=torch.ones(3,3)
B=3*torch.ones(2,3)
C=2*B
D=torch.cat((A,B,C),0)
print(D)
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [3., 3., 3.],
        [3., 3., 3.],
        [6., 6., 6.],
        [6., 6., 6.]])

按列拼接会出问题,所以一定注意维度

D=torch.cat((A,B,C),1)

err: RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1.

3 默认按行拼接(注意括号)

D=torch.cat((A,B,C))
D=torch.cat([A,B,C])
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [3., 3., 3.],
        [3., 3., 3.],
        [6., 6., 6.],
        [6., 6., 6.]])

否则会出现下面的问题
TypeError: cat(): argument 'tensors' (position 1) must be tuple of Tensors, not Tensor
TypeError: cat() takes from 1 to 2 positional arguments but 3 were given

你可能感兴趣的:(pytorch)