torch.stack() 与 torch.cat()

torch.stack() 和 torch.cat() 都是拼接tensor常用操作,stack()可以看做并联,cat()为串联。

torch.stack() 

官网:https://pytorch.org/docs/stable/torch.html#torch.stack

torch.stack(tensorsdim=0out=None) → Tensor 

torch.stack() 将序列连接,形成一个新的tensor结构,此结构中会增加一个维度。连接中的 每个tensor都要保持相同的大小。

参数:

tensors:需要连接的结构 ;dim:需要扩充的维度 ;output:输出的结构

例子:

import torch

l = []
for i in range(0,3):
    x = torch.rand(2,3)
    l.append(x)
print(l)

x = torch.stack(l,dim=0)

print(x.size())

z = torch.stack(l,dim=1)
print(z.size())

output:
[tensor([[0.3615, 0.9595, 0.5895],
        [0.8202, 0.6924, 0.4683]]), tensor([[0.0988, 0.3804, 0.5348],
        [0.0712, 0.4715, 0.1307]]), tensor([[0.1635, 0.4716, 0.1728],
        [0.8023, 0.9664, 0.4934]])]
torch.Size([3, 2, 3])
torch.Size([2, 3, 3])

torch.cat() 

官网:https://pytorch.org/docs/stable/torch.html#torch.cat

torch.cat(tensorsdim=0out=None) → Tensor

参数:同上,tensor必须是相同的维度。

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])
>>> torch.cat(((x,x),x),1) ->报错,不同维度

区别:

下面例子说明torch.cat()与torch.stack()区别。可以看出,stack()是增加新的维度来完成拼接,不改变原维度上的数据大小。cat()是在现有维度上进行数据的增加(改变了现有维度大小),不增加新的维度。

x = torch.rand(2,3)
y = torch.rand(2,3)
print(torch.stack((x,y),1).size())
print(torch.cat((x,y),1).size())

output:
torch.Size([2, 2, 3])
torch.Size([2, 6])

 

你可能感兴趣的:(pytorch,torch.stack(),torch.cat())