Pytorch中torch.cat()

  1. torch.cat(tensors, dim=0)是将两个张量(tensor)拼接在一起,cat是concatenate的意思,即拼接。
    tensors:(或者tensors序列)——提供的非空tensor必须具有相同的shape,cat 维度除外。
    dim (int, optional): —— 拼接tensor的维度

  2. 例子:

>>> import torch
>>> A=torch.ones(2,3) #2x3的张量(矩阵)                                     
>>> A
tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
>>> B=2*torch.ones(4,3)#4x3的张量(矩阵)                                    
>>> B
tensor([[ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.]])
>>> C=torch.cat((A,B),0)#按维数0(行)拼接
>>> C
tensor([[ 1.,  1.,  1.],
         [ 1.,  1.,  1.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.]])
>>> C.size()
torch.Size([6, 3])
>>> D=2*torch.ones(2,4) #2x4的张量(矩阵)
>>> C=torch.cat((A,D),1)#按维数1(列)拼接
>>> C
tensor([[ 1.,  1.,  1.,  2.,  2.,  2.,  2.],
        [ 1.,  1.,  1.,  2.,  2.,  2.,  2.]])
>>> C.size()
torch.Size([2, 7])

总结:torch.cat((A,B),0),将A和B按改变维度0,不改变维度1的方向拼接 (即按行方向拼接,行数变而列数不变)
torch.cat((A,B),1),将A和B按改变维度1,不改变维度0的方向拼接 (即按列方向拼接,列数变而行数不变)

例子部分感谢my-GRIT大佬,链接:
https://blog.csdn.net/qq_39709535/article/details/80803003?spm=1001.2101.3001.6650.3&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7Edefault-3-80803003-blog-96428781.pc_relevant_aa&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7Edefault-3-80803003-blog-96428781.pc_relevant_aa&utm_relevant_index=5

你可能感兴趣的:(深度学习,pytorch,深度学习,python,人工智能)