pytorch中torch.cat()函数理解

pytorch中torch.cat()函数:

功能:拼接两个tensor。

用法:把两个tensor A和B拼接在一起,可进行如下操作:

C = torch.cat( (A,B),0 )  #按维数0拼接(竖着拼)

C = torch.cat( (A,B),1 )  #按维数1拼接(横着拼)

示例说明

1)按维数0拼接

>>> import torch

>>> A=torch.ones(2,3)    #2x3的张量(2行3列的矩阵)                                    

>>> A

tensor([[ 1.,  1.,  1.],

        [ 1.,  1.,  1.]])

>>> B=2*torch.ones(4,3)  #4x3的张量(4行3列的矩阵)                                   

>>> B

tensor([[ 2.,  2.,  2.],

        [ 2.,  2.,  2.],

        [ 2.,  2.,  2.],

        [ 2.,  2.,  2.]])

>>> C=torch.cat((A,B),0)  #按维数0(行)拼接

>>> C

tensor([[ 1.,  1.,  1.],

         [ 1.,  1.,  1.],

         [ 2.,  2.,  2.],

         [ 2.,  2.,  2.],

         [ 2.,  2.,  2.],

         [ 2.,  2.,  2.]])

>>> C.size()

torch.Size([6, 3])

 

2)按维数1拼接

>>> D=2*torch.ones(2,4) #2x4的张量(2行4列的矩阵)

>>> C=torch.cat((A,D),1)#按维数1(列)拼接

>>> C

tensor([[ 1.,  1.,  1.,  2.,  2.,  2.,  2.],

        [ 1.,  1.,  1.,  2.,  2.,  2.,  2.]])

>>> C.size()

torch.Size([2, 7])

你可能感兴趣的:(python)