stack(A,dim=?)
一般此A针对的是list,里面有同样维度大小的tensor数据。
意义是将这些list内部的“同样维度大小的tensor数据”在规定的dim上拼接,得到的是tensor。
import torch
a = torch.Tensor([1,2,3])
b = torch.Tensor([4,5,6])
c = torch.Tensor([7,8,9])
out1 = [a,b,c] #type(out1)=list
print("type(out1)",type(out1))
#stack(A,dim=?)
out1 = torch.stack(out1, dim=1)
print(out1.shape)
print("out1",out1)
结果:
type(out1) <class 'list'>
torch.Size([3, 3])
out1 tensor([[1., 4., 7.],
[2., 5., 8.],
[3., 6., 9.]])
stack((A,B),dim=?)
注意这里写成stack((A,B),dim=?)或者stack([A,B],dim=?)都对。
意义是将这些“同样维度大小的tensor-A、B”在规定的dim上拼接,
其中A,B 是tensor。输出也是tensor。
import torch
a = torch.Tensor([1,2,3])
b = torch.Tensor([4,5,6])
c = torch.Tensor([7,8,9])
#stack((A,B),dim=?) or #stack)[A,B]),dim=?)
out2 = torch.stack((a,b), dim=1)
print(out2.shape)
print("out2",out2)
out3 = torch.stack([a,b,c], dim=1)
print(out3.shape)
print("out3",out3)
torch.Size([3, 2])
out2 tensor([[1., 4.],
[2., 5.],
[3., 6.]])
torch.Size([3, 3])
out3 tensor([[1., 4., 7.],
[2., 5., 8.],
[3., 6., 9.]])
out1与out3是一致的,所以这两种torch.stack只是不同在于输入,后者输入的是两个一样维度的tensor,前者输入的是tensor接起来(比如.append得到)的list。
题外话:至于dim的问题,注意stack()后多了一个维度。
stack()比cat()多了一个维度,可以去搜“stack()比cat()的区别。”
引用别的博主的话:stack()
官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。