官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状,注意与torch.cat的区别。
浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。
import torch
import numpy as np
# 创建3*3的矩阵,a、b
a=np.array([[1,2,3],[4,5,6],[7,8,9]])
b=np.array([[10,20,30],[40,50,60],[70,80,90]])
# 将矩阵转化为Tensor
a = torch.from_numpy(a)
b = torch.from_numpy(b)
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]], dtype=torch.int32)
tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]], dtype=torch.int32)
d = torch.stack((a, b), dim=0)
print(d)
print(d.size())
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]], dtype=torch.int32)
torch.Size([2, 3, 3])
当dim = 0,原来的每一个矩阵也变成了一个维度,一个矩阵看做一个整体。有几个矩阵,新的维度就是几,第几个矩阵就是第几维。
d = torch.stack((a, b), dim=1)
print(d)
print(d.size())
tensor([[[ 1, 2, 3],
[10, 20, 30]],
[[ 4, 5, 6],
[40, 50, 60]],
[[ 7, 8, 9],
[70, 80, 90]]], dtype=torch.int32)
torch.Size([3, 2, 3])
将每个矩阵的第一行组成第一维矩阵,依次下去,每个矩阵的第n行组成第n维矩阵。size=(n,i,y)
d = torch.stack((a, b), dim=2)
print(d)
print(d.size())
tensor([[[ 1, 10],
[ 2, 20],
[ 3, 30]],
[[ 4, 40],
[ 5, 50],
[ 6, 60]],
[[ 7, 70],
[ 8, 80],
[ 9, 90]]], dtype=torch.int32)
torch.Size([3, 3, 2])
dim=2的理解可以参考文献【3】
[1]【Pytorch】torch.stack()的使用
[2]看完秒懂torch.stack()
[3]初学torch.stack()对dim的个人理解