torch.cat(tensors, dim=0, *, out=None) → Tensor
** tensors** (sequence of Tensors) – 任何相同类型张量的 Python序列
dim (int, optional) – 拼接维度,默认为0
import torch
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
print(T1.shape)
print("=============================================")
print(torch.cat((T1, T2), dim=0).shape)
print("=============================================")
print(torch.cat((T1, T2), dim=0))
结果为
torch.Size([3, 3])
=============================================
torch.Size([6, 3])
=============================================
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
可以看到是沿负y轴方向进行拼接
import torch
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
print(T1.shape)
print("=============================================")
print(torch.cat((T1, T2), dim=1).shape)
print("=============================================")
print(torch.cat((T1, T2), dim=1))
结果为
torch.Size([3, 3])
=============================================
torch.Size([3, 6])
=============================================
tensor([[ 1, 2, 3, 10, 20, 30],
[ 4, 5, 6, 40, 50, 60],
[ 7, 8, 9, 70, 80, 90]])
可以看到是沿X轴正方向进行拼接
与dim=1相同
torch.stack(tensors, dim=0, *, out=None) → Tensor
tensors (sequence of Tensors) – 要连接的张量序列
dim (int) – 拼接维度,默认为0
import torch
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
print(torch.stack((T1, T2), dim=0).shape)
print("=============================================")
print(torch.stack((T1, T2), dim=0))
结果为
torch.Size([2, 3, 3])
=============================================
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],
[[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]])
这个是按3维中的第0维拼接,可参考最上面图
import torch
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
print(torch.stack((T1, T2), dim=1).shape)
print("=============================================")
print(torch.stack((T1, T2), dim=1))
结果为
torch.Size([3, 2, 3])
=============================================
tensor([[[ 1, 2, 3],
[10, 20, 30]],
[[ 4, 5, 6],
[40, 50, 60]],
[[ 7, 8, 9],
[70, 80, 90]]])
import torch
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
print(torch.stack((T1, T2), dim=2).shape)
print("=============================================")
print(torch.stack((T1, T2), dim=2))
结果为
torch.Size([3, 3, 2])
=============================================
tensor([[[ 1, 10],
[ 2, 20],
[ 3, 30]],
[[ 4, 40],
[ 5, 50],
[ 6, 60]],
[[ 7, 70],
[ 8, 80],
[ 9, 90]]])
结果与dim=2相同
部分内容参考以下链接,这里表示感谢 Thanks♪(・ω・)ノ
参考博文1 torch.cat() 官方文档
https://pytorch.org/docs/1.13/generated/torch.cat.html
参考博文2 torch.stack()官方文档
https://pytorch.org/docs/1.13/generated/torch.stack.html
参考博文3 torch.cat()函数的官方解释,详解以及例子
https://blog.csdn.net/xinjieyuan/article/details/105208352
参考博文4 torch.stack()的官方解释,详解以及例子
https://blog.csdn.net/xinjieyuan/article/details/105205326