torch.cat()用法详解

先把官方解释展示一下

torch.cat()用法详解_第1张图片

torch.cat(tensor,dim=0) 

第一个参数是tensor,第二个是轴,-1,0,1,2.....,这取决于你的tensor是几维空间。

它的功能是将多个tensor类型矩阵的连接。它有两个参数,
第一个是tensor元组或者tensor列表;
第二个是dim,如果tensor是二维的,dim=0指在行上连接,dim=1指在列上连接。
但是注意这里在行上连接,是扩展行进行连接,在列上连接是扩展列连接。
注意:torch.cat 进行连接的tensor的shape,除了需要连接的维度上的shape值可不同,
必须拥有相同的shape,a是(2,3),b是(2,20)即torch.cat((a,b),-1)可以进行连接;torch.cat((a,b),0)不可以进行连接,因为3和20值不同

它的功能是将多个tensor类型矩阵的连接。它有两个参数,

第一个是tensor元组或者tensor列表;

第二个是dim,如果tensor是二维的,dim=0指在行上连接,dim=1指在列上连接。

但是注意这里在行上连接,是扩展行进行连接,在列上连接是扩展列连接。
注意:torch.cat 进行连接的tensor的shape,除了需要连接的维度上的shape值可不同,必须拥有相同的shape,a是(2,3),b是(2,20)即torch.cat((a,b),-1)可以进行连接;

torch.cat((a,b),0)不可以进行连接,因为3和20值不同

例子一:

    a=torch.randn(2,3)
    print(a)
    b=torch.cat((a,a,a),1)
    print(b)
    
'''
#输出结果,dim=1,可见扩展列了
tensor([[-0.1121, -0.2641,  0.4476],
        [-1.2637,  1.0789,  1.0342]])
tensor([[-0.1121, -0.2641,  0.4476, -0.1121, -0.2641,  0.4476, -0.1121, -0.2641,
          0.4476],
        [-1.2637,  1.0789,  1.0342, -1.2637,  1.0789,  1.0342, -1.2637,  1.0789,
          1.0342]])
'''

例子二 ;看看为啥出现错误(这个是Ipython 写的),代码最后,因为维度的shape不一样,报错了
 

In [11]: import torch as tr
 
In [12]: A=tr.ones(2,3)
 
In [13]: A
Out[13]: 
tensor([[1., 1., 1.],
        [1., 1., 1.]])
 
In [14]: B=2*tr.ones(4,3)
 
In [15]: B
Out[15]: 
tensor([[2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])
 
In [16]: C=tr.cat((A,B),0)
 
In [17]: C
Out[17]: 
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.],
        [2., 2., 2.]])
 
In [18]: C=tr.cat((A,B),1)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [18], in ()
----> 1 C=tr.cat((A,B),1)
 
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 2 but got size 4 for tensor number 1 in the list.

你可能感兴趣的:(pytorch,深度学习,python,pytorch)