PyTorch: Conv2d and ConvTranspose2d

ConvTranspose2d 实现的是 Conv2d 的逆过程,也就是将一张 m × m m \times m m×m 的图片,upsampling 到 n × n n \times n n×n,这里 n > m n > m n>m。 ConvTranspose2d 的实现方法,与 Assignment 2 | ConvolutionalNetworks 计算 dx 的方法完全相同。实际上,不论在 PyTorch 还是在 TensorFlow 里面,ConvTranspose2d 的实现和计算 dx 的梯度的实现,使用的是同一段代码。在 PyTorch 的文档里明确说明了这一点:

This module can be seen as the gradient of Conv2d with respect to its input.

这里先把 Conv2d 中计算 dx 的方法写一下:

dx 的计算方法

这里写出 dx 和 dw 的闭式解很复杂,而且不容易写出代码,所以这里用一个例子来推出 dx 和 dw 的计算过程,根据此计算过程可以将代码写出。这里 stride = 1,pad = 0,x,w,y为:
x = [ x 11 x 12 x 13 x 21 x 22 x 23 x 31 x 32 x 33 ] , w = [ w 11 w 12 w 21 w 22 ] , y = [ y 11 y 12 y 21 y 22 ] x = \begin{bmatrix} x_{11}& x_{12}& x_{13}\newline x_{21}& x_{22}& x_{23}\newline x_{31}& x_{32}& x_{33} \end{bmatrix}, \quad w = \begin{bmatrix} w_{11}& w_{12}\newline w_{21}& w_{22} \end{bmatrix}, \quad y = \begin{bmatrix} y_{11}& y_{12}\newline y_{21}& y_{22} \end{bmatrix} x=[x11x12x13x21x22x23x31x32x33],w=[w11w12w21w22],y=[y11y12y21y22]
y = x × w y = x \times w y=x×w 展开:
y 11 = w 11 x 11 + w 12 x 12 + w 21 x 21 + w 22 x 22 y 12 = w 11 x 12 + w 12 x 13 + w 21 x 22 + w 22 x 23 y 21 = w 11 x 21 + w 12 x 22 + w 21 x 31 + w 22 x 32 y 22 = w 11 x 22 + w 12 x 23 + w 21 x 32 + w 22 x 33 \begin{aligned} y_{11} &= w_{11}x_{11} + w_{12}x_{12} + w_{21}x_{21} + w_{22}x_{22} \newline y_{12} &= w_{11}x_{12} + w_{12}x_{13} + w_{21}x_{22} + w_{22}x_{23} \newline y_{21} &= w_{11}x_{21} + w_{12}x_{22} + w_{21}x_{31} + w_{22}x_{32} \newline y_{22} &= w_{11}x_{22} + w_{12}x_{23} + w_{21}x_{32} + w_{22}x_{33} \newline \end{aligned} y11=w11x11+w12x12+w21x21+w22x22y12=w11x12+w12x13+w21x22+w22x23y21=w11x21+w12x22+w21x31+w22x32y22=w11x22+w12x23+w21x32+w22x33
所以:
d x = [ ∂ L ∂ y ∂ y ∂ x 11 ∂ L ∂ y ∂ y ∂ x 12 ∂ L ∂ y ∂ y ∂ x 13 ∂ L ∂ y ∂ y ∂ x 21 ∂ L ∂ y ∂ y ∂ x 22 ∂ L ∂ y ∂ y ∂ x 23 ∂ L ∂ y ∂ y ∂ x 31 ∂ L ∂ y ∂ y ∂ x 32 ∂ L ∂ y ∂ y ∂ x 33 ] \mathrm{d} x = \begin{bmatrix} \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{11}}& \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{12}}& \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{13}}\newline \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{21}}& \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{22}}& \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{23}}\newline \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{31}}& \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{32}}& \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{33}} \end{bmatrix} dx=[yLx11yyLx12yyLx13yyLx21yyLx22yyLx23yyLx31yyLx32yyLx33y]
x 11 x_{11} x11 相关的仅有 y 11 y_{11} y11,所以第一项 ∂ L ∂ y ∂ y ∂ x 11 = ∂ y 11 ⋅ w 11 \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{11}} = \partial y_{11} \cdot w_{11} yLx11y=y11w11,与 x 12 x_{12} x12 相关的有两项 y 11 y_{11} y11 y 12 y_{12} y12,所以第二项 ∂ L ∂ y ∂ y ∂ x 12 = ∂ y 11 ⋅ w 12 + ∂ y 12 ⋅ w 11 \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{12}} = \partial y_{11} \cdot w_{12} + \partial y_{12} \cdot w_{11} yLx12y=y11w12+y12w11,依次类推,将最后结果写成如下形式就一目了然了:
d x = [ ∂ y 11 ⋅ w 11 ∂ y 11 ⋅ w 12 + ∂ y 12 ⋅ w 11 ∂ y 12 ⋅ w 12 ∂ y 11 ⋅ w 21 + ∂ y 11 ⋅ w 22 + ∂ y 12 ⋅ w 21 + ∂ y 12 ⋅ w 22 ∂ y 21 ⋅ w 11 ∂ y 21 ⋅ w 12 + ∂ y 22 ⋅ w 11 ∂ y 22 ⋅ w 12 ∂ y 21 ⋅ w 21 ∂ y 21 ⋅ w 22 + ∂ y 22 ⋅ w 21 ∂ y 22 ⋅ w 22 ] \mathrm{d} x = \begin{bmatrix} \partial y_{11} \cdot w_{11}& \partial y_{11} \cdot w_{12} + & \newline &\partial y_{12} \cdot w_{11} & \partial y_{12} \cdot w_{12} \newline & & \newline \partial y_{11} \cdot w_{21} + & \partial y_{11} \cdot w_{22} + & \newline & \partial y_{12} \cdot w_{21} + & \partial y_{12} \cdot w_{22} \newline \partial y_{21} \cdot w_{11} & \partial y_{21} \cdot w_{12} + & \newline & \partial y_{22} \cdot w_{11} & \partial y_{22} \cdot w_{12} \newline & & \newline \partial y_{21} \cdot w_{21} & \partial y_{21} \cdot w_{22} + & \newline & \partial y_{22} \cdot w_{21} & \partial y_{22} \cdot w_{22} \newline \end{bmatrix} dx=[y11w11y11w12+y12w11y12w12y11w21+y11w22+y12w21+y12w22y21w11y21w12+y22w11y22w12y21w21y21w22+y22w21y22w22]
显然,dx的计算方法是在一个形如 x 的矩阵上滑动,先计算
∂ y 11 ⋅ [ w 11 w 12 w 21 w 22 ] \partial y_{11} \cdot \begin{bmatrix} w_{11}& w_{12}\newline w_{21}& w_{22} \end{bmatrix} y11[w11w12w21w22]
并将结果放在 dx 的第一个形如 w 的块上,然后计算
∂ y 12 ⋅ [ w 11 w 12 w 21 w 22 ] \partial y_{12} \cdot \begin{bmatrix} w_{11}& w_{12}\newline w_{21}& w_{22} \end{bmatrix} y12[w11w12w21w22]
滑动 stride,并将结果放在 dx 的第二个形如 w 的块上,依次类推。

这里需要注意的是:

  1. 例子里的 pad = 0。如果 pad 不为0的话,所有对 x 的计算都要针对扩充后的 x_pad,得到的结果也是 dx_pad,最后返回的结果 dx 要将 dx_pad 去掉 pad。
  2. 滑动的次数由 dout 的形状决定,滑动的步长由 stride 决定。

以上仅是针对 x 的最后两个维度的计算,前两个维度加循环即可

for i in range(N):
    for oc in range(K):
        for ww in range(out_w):
            for hh in range(out_h):
                dpad_x[i, :, (s*hh):(s*hh+f_h), (s*ww):(s*ww+f_w)] += dout[i, oc, hh, ww] * w[oc, ...]

dx = dpad_x[:, :, p:(in_h+p), p:(in_w+p)]

实际上,我用 PyTorch 的 torch.nn.functional.conv_transpose2d 实现了一下计算 dx 的梯度,得到的结果是一样的:

dout_tensor = torch.from_numpy(dout)
w_tensor = torch.from_numpy(w)
dx = torch.nn.functional.conv_transpose2d(dout_tensor, w_tensor, bias=None, stride=s, padding=p)
dx = dx.numpy()

ConvTranspose2d 中的参数设置

ConvTranspose2d 中的参数设置,特别是 padding 和 函数的输出形状破费一些思量。我的建议是首先找出其对偶的 Conv2d,此 Conv2d 中的参数就是 ConvTranspose2d 中的参数,除了输入输出形状互换以外。通过下面这个例子详细说一下。

作业中在实现 CNN GAN 时,给出的 Generator 实现为:

Reshape into Image Tensor of shape 7, 7, 128

Conv2D^T (Transpose): 64 filters of 4x4, stride 2, ‘same’ padding

ReLU
BatchNorm
Conv2D^T (Transpose): 1 filter of 4x4, stride 2, ‘same’ padding

第一个 ConvTranspose2d 的输入形状显然是 (7, 7),有 128 个 channel,ConvTranspose2d 的 filter 形状是 (4, 4),stride = 2,使用 ‘SAME’ padding。那么,ConvTranspose2d 的输出形状到底是多少呢?这里的 ‘SAME’ padding 到底是 pad 多少 0 呢?

首先,‘SAME’ padding 的说法显然是来自 TensorFlow 而非 PyTorch,所以,我们看 TensorFlow 的 tf.nn.conv2d_transpose 手册。里面在讲 padding 的时候,直接链接到了 tf.nn.convolution 中关于 padding 的说明:

If padding == “SAME”: output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])

If padding == “VALID”: output_spatial_shape[i] = ceil((input_spatial_shape[i] - (spatial_filter_shape[i]-1) * dilation_rate[i]) / strides[i]).

注意,这里的 input 和 output 是针对 Conv2d 而言的,它正好和 ConvTranspose2d 是相反的!所以,如果使用 ‘SAME’ 的话,ConvTranspose2d 的 output 应该是 output = input * stride。因为 Conv2d 在计算时有一个 ceil 在算式里,所以,ConvTranspose2d 的 output 大小不是唯一的。以作业为例,ConvTranspose2d 的输入形如 (7, 7),stride = 2,那么其对偶的 Conv2d 的输入大小可以是 7 * 2 = 14,也可以是 7 * 2 + 1 = 15。因为 Conv2d 计算输出大小的公式是 ceil ( input / stride )。这也是为什么在 PyTorch 版本的 ConvTranspose2d 中还要额外给一个 output_padding 的参数,而且还有一个 Note 说:

However, when :attrstride >1, Conv2d maps multiple input shapes to the same output shape. output_padding is provided to resolve this ambiguity by effectively increasing the calculated output shape on one side. Note that output_padding is only used to find output shape, but does not actually add zero-padding to output.

到这里,我们得到了对偶的 Conv2d 的输入是 (14, 14) 或者 (15, 15),输出是 (7, 7),kernel 是 (4, 4),stride = 2,那么 padding 就可以计算出来了:

  • 如果输入是取 (14, 14) 的话,(14 - 4 + 2 * padding) / 2 + 1 = 7,此时的 padding 是 1.
  • 如果输入是取 (15, 15) 的话,(15 - 4 + 2 * padding) / 2 + 1 = 7,那么 padding 是 0.5。

这两个结果没有对错之分,只不过我们取一个偶数的值是比较好的,所以这里取 14。到此为止,Conv2d 的所有参数都已经确定,即

torch.nn.Conv2d(128, 64, (4, 4), stride=2, padding=1)

它的输入是形如 (batch_size, 128, 14, 14),输出是形如 (batch_size, 64, 7, 7) 的。

因此,其对偶的 ConvTranspose2d 的参数与其完全相同,即为:

torch.nn.ConvTranspose2d(128, 64, (4, 4), stride=2, padding=1)

它的输入是形如 (batch_size, 128, 7, 7),输出是形如 (batch_size, 64, 14, 14) 的。从而将一张 7 * 7 的图片 upsampling 到了 14 * 14。

第二个 ConvTranspose2d 层用同样方法计算

Conv2D^T (Transpose): 1 filter of 4x4, stride 2, ‘same’ padding

Conv2d 的输出为 14 * 14 的图片,采用 stride 2,‘same’ padding,所以 Conv2d 的输入为28或者29,这里取28。 需要 padding = 1,所以

torch.nn.Conv2d(64, 1, (4, 4), stride=2, padding=1)

那么其对偶的 ConvTranspose2d 为:

torch.nn.ConvTranspose2d(64, 1, (4, 4), stride=2, padding=1)

你可能感兴趣的:(PyTorch)