转置卷积(Transposed Convolution)

在介绍UNet的时候,我们提到了转置卷积,在UNet右侧分支上,对特征上采样的其中一种实现方式即为转置卷积(另一种为双线性插值)。所以今天我们就来看看转置卷积的实现细节。

由于篇幅原因,本篇不展开太多,只讲核心实现。关于转置卷积(以及各种卷积的)详细实现,可以参考论文:A guide to convolution arithmetic for deep learning

转置卷积首先也是一种卷积操作,绝大部分转置卷积是为了实现上采样的目的。之所以使用转置卷积代替其他简单上采样算法,如最邻近插值、双线性插值等,是因为转置卷积有可以学习的参数,使得网络可以学到最优上采样方法。

转置卷积不是卷积的逆变换,而是一种保持矩阵元素对应关系的操作,由于是上采样,所以一般是1对多的关系,也就是输入矩阵的一个元素,对应输出矩阵相应位置上的多个元素。

下面两图展示了卷积和转置卷积的区别(蓝色为输入矩阵,绿色为输出矩阵)。

转置卷积(Transposed Convolution)_第1张图片转置卷积(Transposed Convolution)_第2张图片

接下来,我们就来介绍具体的实现。

1. 转置卷积的实现步骤

首先,先来定义一些变量。

  • 输入矩阵宽和高分别是:Win,Hin;
  • 输出矩阵宽和高分别为:Wout, Hout;
  • 步长为stride,简记为s;
  • 填充像素数padding,简记为p;
  • 卷积核大小kernel size,简记为k。

转置卷积的运算步骤为:

  • 在输入矩阵的行间和列间分别填充s-1个0;
  • 在输入矩阵的四周,分别填充k-p-1行(列)0;
  • 卷积操作。

那么我们可以得到输出矩阵的大小为(这里只考虑行和列上的padding size相同,且为正方形kernel):

  • Hout = (Hin - 1) x s - 2 x p + k
  • Wout = (Win - 1) x 2 - 2 x p + k

例如,看下面几种情况:

(1) s = 1, p = 0, k = 3

转置卷积(Transposed Convolution)_第3张图片

(2) s = 2, p = 0, k = 3

转置卷积(Transposed Convolution)_第4张图片

(3) s = 2, p = 1, k = 3

 转置卷积(Transposed Convolution)_第5张图片

更多转置卷积形式可参考: GitHub - vdumoulin/conv_arithmetic: A technical report on convolution arithmetic in the context of deep learning

2. PyTorch实现ConvTranspose2d 

 PyTorch中对转置卷积的实现为:

CLASS torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, padding_mode='zeros', device=None, dtype=None)

各参数的解释如下:

转置卷积(Transposed Convolution)_第6张图片

用法举例如下:

# With square kernels and equal stride
m = nn.ConvTranspose2d(16, 33, 3, stride=2)
# non-square kernels and unequal stride and with padding
m = nn.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
input = torch.randn(20, 16, 50, 100)
output = m(input)
# exact output size can be also specified as an argument
input = torch.randn(1, 16, 12, 12)
downsample = nn.Conv2d(16, 16, 3, stride=2, padding=1)
upsample = nn.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
h = downsample(input)
h.size()
output = upsample(h, output_size=input.size())
output.size()

你可能感兴趣的:(深度学习,cnn,深度学习,人工智能)