[PyTorch] 拼接多个tensor:torch.cat((A,B),axis)

注:参考博客Pytorch中的torch.cat()函数。本人在其基础上增加了更为详细的解释。

torch.cat((A,B),axis)是对A, B两个tensor进行拼接。

参数axis指定拼接的方式。axis=0为按行拼接;axis=1为按列拼接。

拼接的时候把待拼接的tensor视作整体。(注意在示例中理解这句话)

import torch

# 初始化三个 tensor
A=torch.ones(2,3)    #2x3的张量(矩阵)                                     
# tensor([[ 1.,  1.,  1.],
#         [ 1.,  1.,  1.]])
B=2*torch.ones(4,3)  #4x3的张量(矩阵)                                    
# tensor([[ 2.,  2.,  2.],
#         [ 2.,  2.,  2.],
#         [ 2.,  2.,  2.],
#         [ 2.,  2.,  2.]])
D=2*torch.ones(2,4)	 # 2x4的张量(矩阵)
# tensor([[ 2.,  2.,  2., 2.],
#         [ 2.,  2.,  2., 2.],

# 按维数0(行)拼接 A 和 B
C=torch.cat((A,B),0) 
# tensor([[ 1.,  1.,  1.],
#          [ 1.,  1.,  1.],
#          [ 2.,  2.,  2.],
#          [ 2.,  2.,  2.],
#          [ 2.,  2.,  2.],
#          [ 2.,  2.,  2.]])
print(C.shape)
# torch.Size([6, 3])


# 按维数1(列)拼接 A 和 D
C=torch.cat((A,D),1)
# tensor([[ 1.,  1.,  1.,  2.,  2.,  2.,  2.],
#         [ 1.,  1.,  1.,  2.,  2.,  2.,  2.]])
print(C.shape)
# torch.Size([2, 7])

另外,torch.cat((A,B),axis)还能把list中的tensor拼接起来。

import torch

x = torch.Tensor([1, 2, 3])
x = x.unsqueeze(1)
x2 = torch.cat( [ x*2 for i in range (1,4) ], 1 )
# tensor([[2., 2., 2.],
#         [4., 4., 4.],
#         [6., 6., 6.]])

x = torch.Tensor([1, 2, 3])生成的xshapetorch.Size([3]),我们需要用x = x.unsqueeze(1)x增加第二个维度,使其变为二维的tensortorch.Size([3, 1])

关于升维函数 x.unsqueeze(axis) 和降维函数 x.unsqueeze(axis) 的详细说明,可以去我的另一篇博客增加维度或者减少维度 ——a.squeeze(axis) 和 a.unsqueeze(axis)

torch.cat( [ x*2 for i in range (1,4) ], 1 )先生成了一个包含 3tensorlist,然后对list中的元素按列拼接(axis=1)。
关于range这个基本的函数,见本人的另一篇博客创建一个整数列表—— range()
故最后的结果为:

# tensor([[2., 2., 2.],
#         [4., 4., 4.],
#         [6., 6., 6.]])

你可能感兴趣的:(PyTorch)