Pytorch:torch.cat()

1 torch.cat()

在给定维度上对输入的张量序列seq 进行连接操作。torch.cat()可以看做 torch.split() 和 torch.chunk()的反操作。 cat() 函数可以通过下面例子更好的理解。

torch.cat(seq, dim=0, out=None) → Tensor
参数 描述
seq(Tensors的序列) 可以是相同类型的Tensor的任何python序列
dim(int,可选) 沿着此维连接张量序列
out(Tensor,可选) 输出参数

2 实例

  • 栗子1
>>> import torch
>>> A=torch.ones(2,3) #2x3的张量(矩阵)                                     
>>> A
tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
>>> B=2*torch.ones(4,3)#4x3的张量(矩阵)                                    
>>> B
tensor([[ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.]])
>>> C=torch.cat((A,B),0)#按维数0(行)拼接
>>> C
tensor([[ 1.,  1.,  1.],
         [ 1.,  1.,  1.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.]])
>>> C.size()
torch.Size([6, 3])
>>> D=2*torch.ones(2,4) #2x4的张量(矩阵)
>>> C=torch.cat((A,D),1)#按维数1(列)拼接
>>> C
tensor([[ 1.,  1.,  1.,  2.,  2.,  2.,  2.],
        [ 1.,  1.,  1.,  2.,  2.,  2.,  2.]])
>>> C.size()
torch.Size([2, 7])

上面给出了两个张量A和B,分别是2行3列,4行3列。即他们都是2维张量。因为只有两维,这样在用torch.cat拼接的时候就有两种拼接方式:按行拼接和按列拼接。即所谓的维数0和维数1. 

C=torch.cat((A,B),0)就表示按维数0(行)拼接A和B,也就是竖着拼接,A上B下。此时需要注意:列数必须一致,即维数1数值要相同,这里都是3列,方能列对齐。拼接后的C的第0维是两个维数0数值和,即2+4=6.

C=torch.cat((A,B),1)就表示按维数1(列)拼接A和B,也就是横着拼接,A左B右。此时需要注意:行数必须一致,即维数0数值要相同,这里都是2行,方能行对齐。拼接后的C的第1维是两个维数1数值和,即3+4=7.

从2维例子可以看出,使用torch.cat((A,B),dim)时,除拼接维数dim数值可不同外其余维数数值需相同,方能对齐
 

  • 栗子2:高维
import torch
a = torch.randn(2,3,4)
a
Out[4]: 
tensor([[[ 0.1728,  1.2299,  1.1025,  1.2469],
         [-1.7841, -0.0682, -1.3842,  0.1034],
         [-0.2922,  0.5537,  1.8052,  1.6766]],
        [[ 0.6606, -0.1216,  0.6336, -0.9340],
         [-0.3626, -1.1162,  0.8975, -0.1320],
         [-0.2133,  0.3769,  0.5940, -1.3333]]])
b = torch.cat((a,a),dim=0)
b
Out[6]: 
tensor([[[ 0.1728,  1.2299,  1.1025,  1.2469],
         [-1.7841, -0.0682, -1.3842,  0.1034],
         [-0.2922,  0.5537,  1.8052,  1.6766]],
        [[ 0.6606, -0.1216,  0.6336, -0.9340],
         [-0.3626, -1.1162,  0.8975, -0.1320],
         [-0.2133,  0.3769,  0.5940, -1.3333]],
        [[ 0.1728,  1.2299,  1.1025,  1.2469],
         [-1.7841, -0.0682, -1.3842,  0.1034],
         [-0.2922,  0.5537,  1.8052,  1.6766]],
        [[ 0.6606, -0.1216,  0.6336, -0.9340],
         [-0.3626, -1.1162,  0.8975, -0.1320],
         [-0.2133,  0.3769,  0.5940, -1.3333]]])
b.size()
Out[7]: torch.Size([4, 3, 4])


3 深度学习中使用案例

在深度学习处理图像时,常用的有3通道的RGB彩色图像及单通道的灰度图。张量size为cxhxw,即通道数x图像高度x图像宽度。在用torch.cat拼接两张图像时一般要求图像大小一致而通道数可不一致,即h和w同,c可不同。当然实际有3种拼接方式,另两种好像不常见。比如经典网络结构:U-Net:

里面用到4次torch.cat,其中copy and crop操作就是通过torch.cat来实现的。可以看到通过上采样(up-conv 2x2)将原始图像h和w变为原来2倍,再和左边直接copy过来的同样h,w的图像拼接。这样做,可以有效利用原始结构信息。

参考:

  • https://blog.csdn.net/weixin_43914889/article/details/104616034
  • https://blog.csdn.net/qq_39709535/article/details/80803003

 

你可能感兴趣的:(修仙之路:pytorch篇)