torch.stack() 和 torch.cat() 都是拼接tensor常用操作,stack()可以看做并联,cat()为串联。
torch.stack()
官网:https://pytorch.org/docs/stable/torch.html#torch.stack
torch.
stack
(tensors, dim=0, out=None) → Tensor
torch.stack() 将序列连接,形成一个新的tensor结构,此结构中会增加一个维度。连接中的 每个tensor都要保持相同的大小。
参数:
tensors:需要连接的结构 ;dim:需要扩充的维度 ;output:输出的结构
例子:
import torch
l = []
for i in range(0,3):
x = torch.rand(2,3)
l.append(x)
print(l)
x = torch.stack(l,dim=0)
print(x.size())
z = torch.stack(l,dim=1)
print(z.size())
output:
[tensor([[0.3615, 0.9595, 0.5895],
[0.8202, 0.6924, 0.4683]]), tensor([[0.0988, 0.3804, 0.5348],
[0.0712, 0.4715, 0.1307]]), tensor([[0.1635, 0.4716, 0.1728],
[0.8023, 0.9664, 0.4934]])]
torch.Size([3, 2, 3])
torch.Size([2, 3, 3])
torch.cat()
官网:https://pytorch.org/docs/stable/torch.html#torch.cat
torch.
cat
(tensors, dim=0, out=None) → Tensor
参数:同上,tensor必须是相同的维度。
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497],
[ 0.6580, -1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614, 0.6580, -1.0969, -0.4614, 0.6580,
-1.0969, -0.4614],
[-0.1034, -0.5790, 0.1497, -0.1034, -0.5790, 0.1497, -0.1034,
-0.5790, 0.1497]])
>>> torch.cat(((x,x),x),1) ->报错,不同维度
区别:
下面例子说明torch.cat()与torch.stack()区别。可以看出,stack()是增加新的维度来完成拼接,不改变原维度上的数据大小。cat()是在现有维度上进行数据的增加(改变了现有维度大小),不增加新的维度。
x = torch.rand(2,3)
y = torch.rand(2,3)
print(torch.stack((x,y),1).size())
print(torch.cat((x,y),1).size())
output:
torch.Size([2, 2, 3])
torch.Size([2, 6])