傻傻分清stack(A,dim=?)和stack((A,B),dim=?)

stack(A,dim=?)
一般此A针对的是list,里面有同样维度大小的tensor数据。
意义是将这些list内部的“同样维度大小的tensor数据”在规定的dim上拼接,得到的是tensor。

import torch
a = torch.Tensor([1,2,3])
b = torch.Tensor([4,5,6])
c = torch.Tensor([7,8,9])
out1 = [a,b,c]           #type(out1)=list
print("type(out1)",type(out1))
#stack(A,dim=?)
out1 = torch.stack(out1, dim=1)
print(out1.shape)
print("out1",out1)

结果:

type(out1) <class 'list'>
torch.Size([3, 3])
out1 tensor([[1., 4., 7.],
        [2., 5., 8.],
        [3., 6., 9.]])

stack((A,B),dim=?)
注意这里写成stack((A,B),dim=?)或者stack([A,B],dim=?)都对。
意义是将这些“同样维度大小的tensor-A、B”在规定的dim上拼接,
其中A,B 是tensor。输出也是tensor。

import torch
a = torch.Tensor([1,2,3])
b = torch.Tensor([4,5,6])
c = torch.Tensor([7,8,9])

#stack((A,B),dim=?) or #stack)[A,B]),dim=?)
out2 = torch.stack((a,b), dim=1)
print(out2.shape)
print("out2",out2)

out3 = torch.stack([a,b,c], dim=1)
print(out3.shape)
print("out3",out3)
torch.Size([3, 2])
out2 tensor([[1., 4.],
        [2., 5.],
        [3., 6.]])
torch.Size([3, 3])
out3 tensor([[1., 4., 7.],
        [2., 5., 8.],
        [3., 6., 9.]])

out1与out3是一致的,所以这两种torch.stack只是不同在于输入,后者输入的是两个一样维度的tensor,前者输入的是tensor接起来(比如.append得到)的list。

题外话:至于dim的问题,注意stack()后多了一个维度。
stack()比cat()多了一个维度,可以去搜“stack()比cat()的区别。”

引用别的博主的话:stack()
官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。

你可能感兴趣的:(pytorch,深度学习,python)