注:参考博客Pytorch中的torch.cat()函数。本人在其基础上增加了更为详细的解释。
torch.cat((A,B),axis)
是对A
, B
两个tensor
进行拼接。
参数axis指定拼接的方式。axis=0
为按行拼接;axis=1
为按列拼接。
拼接的时候把待拼接的tensor
视作整体。(注意在示例中理解这句话)
import torch
# 初始化三个 tensor
A=torch.ones(2,3) #2x3的张量(矩阵)
# tensor([[ 1., 1., 1.],
# [ 1., 1., 1.]])
B=2*torch.ones(4,3) #4x3的张量(矩阵)
# tensor([[ 2., 2., 2.],
# [ 2., 2., 2.],
# [ 2., 2., 2.],
# [ 2., 2., 2.]])
D=2*torch.ones(2,4) # 2x4的张量(矩阵)
# tensor([[ 2., 2., 2., 2.],
# [ 2., 2., 2., 2.],
# 按维数0(行)拼接 A 和 B
C=torch.cat((A,B),0)
# tensor([[ 1., 1., 1.],
# [ 1., 1., 1.],
# [ 2., 2., 2.],
# [ 2., 2., 2.],
# [ 2., 2., 2.],
# [ 2., 2., 2.]])
print(C.shape)
# torch.Size([6, 3])
# 按维数1(列)拼接 A 和 D
C=torch.cat((A,D),1)
# tensor([[ 1., 1., 1., 2., 2., 2., 2.],
# [ 1., 1., 1., 2., 2., 2., 2.]])
print(C.shape)
# torch.Size([2, 7])
另外,torch.cat((A,B),axis)
还能把list
中的tensor
拼接起来。
import torch
x = torch.Tensor([1, 2, 3])
x = x.unsqueeze(1)
x2 = torch.cat( [ x*2 for i in range (1,4) ], 1 )
# tensor([[2., 2., 2.],
# [4., 4., 4.],
# [6., 6., 6.]])
x = torch.Tensor([1, 2, 3])
生成的x
的shape
为torch.Size([3])
,我们需要用x = x.unsqueeze(1)
为x
增加第二个维度,使其变为二维的tensor
:torch.Size([3, 1])
。
关于升维函数 x.unsqueeze(axis)
和降维函数 x.unsqueeze(axis)
的详细说明,可以去我的另一篇博客增加维度或者减少维度 ——a.squeeze(axis) 和 a.unsqueeze(axis)
torch.cat( [ x*2 for i in range (1,4) ], 1 )
先生成了一个包含 3 个tensor
的list
,然后对list
中的元素按列拼接(axis=1
)。
关于range这个基本的函数,见本人的另一篇博客创建一个整数列表—— range()
故最后的结果为:
# tensor([[2., 2., 2.],
# [4., 4., 4.],
# [6., 6., 6.]])