stack在英文中有“堆叠的意思”。所以stack通常是把一些低纬(二维)的tensor堆叠为一个高维(三维)的tensor。
stack()官方解释:torch.stack[source] → Tensor :
函数目的: 沿着一个新维度对输入张量序列进行拼接 。其中序列中所有的 张量 都应该为相同形状。
outputs = torch.stack(inputs, dim=0) # → Tensor
参数:
例子
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
R0 = torch.stack((T1, T2), dim=0)
print("R0:\n", R0)
print("R0.shape:\n", R0.shape)
"""
R0:
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]])
R0.shape:
torch.Size([2, 3, 3])
"""
R1 = torch.stack((T1, T2), dim=1)
print("R1.shape:\n", R1.shape)
"""
R1:
tensor([[[ 1, 2, 3],
[10, 20, 30]],
[[ 4, 5, 6],
[40, 50, 60]],
[[ 7, 8, 9],
[70, 80, 90]]])
R1.shape:
torch.Size([3, 2, 3])
"""
R2 = torch.stack((T1, T2), dim=2)
print("R2:\n", R2)
print("R2.shape:\n", R2.shape)
"""
R2:
tensor([[[ 1, 10],
[ 2, 20],
[ 3, 30]],
[[ 4, 40],
[ 5, 50],
[ 6, 60]],
[[ 7, 70],
[ 8, 80],
[ 9, 90]]])
R2.shape:
torch.Size([3, 3, 2])
"""
R3 = torch.stack((T1, T2), dim=3)
print("R3:\n", R3)
print("R3.shape:\n", R3.shape)
"""
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
"""
一般torch.cat()是为了把函数torch.stack()得到tensor进行拼接而存在的。torch.cat() 和python中的内置函数cat(), 在使用和目的上,是没有区别的,区别在于前者操作对象是tensor。
函数目的:
在给定维度上对输入的张量序列seq 进行连接操作。
outputs = torch.cat(inputs, dim=0) → Tensor
参数:
重点:
例子
x1 = torch.tensor([[11, 21, 31], [21, 31, 41]], dtype=torch.int)
print("x1:\n", x1)
print("x1.shape:\n", x1.shape)
'''
x1:
tensor([[11, 21, 31],
[21, 31, 41]], dtype=torch.int32)
x1.shape:
torch.Size([2, 3])
'''
x2 = torch.tensor([[12, 22, 32], [22, 32, 42]])
print("x2:\n", x2)
print("x2.shape:\n", x2.shape)
'''
x2:
tensor([[12, 22, 32],
[22, 32, 42]])
x2.shape:
torch.Size([2, 3])
'''
inputs = [x1, x2]
print("inputs:\n", inputs)
'''
inputs:
[tensor([[11, 21, 31],
[21, 31, 41]], dtype=torch.int32), tensor([[12, 22, 32],
[22, 32, 42]])]
'''
R0 = torch.cat(inputs, dim=0)
print("R0:\n", R0)
print("R0.shape:\n", R0.shape)
'''
R0:
tensor([[11, 21, 31],
[21, 31, 41],
[12, 22, 32],
[22, 32, 42]])
R0.shape:
torch.Size([4, 3])
'''
R1 = torch.cat(inputs, dim=1)
print("R1:\n", R1)
print("R1.shape:\n", R1.shape)
'''
R1:
tensor([[11, 21, 31, 12, 22, 32],
[21, 31, 41, 22, 32, 42]])
R1.shape:
torch.Size([2, 6])
'''
R2 = torch.cat(inputs, dim=2)
print("R2:\n", R2)
print("R2.shape:\n", R2.shape)
'''
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
'''
参考链接:https://blog.csdn.net/qq_40507857/article/details/119854085