ConvTranspose2d 实现的是 Conv2d 的逆过程,也就是将一张 m × m m \times m m×m 的图片,upsampling 到 n × n n \times n n×n,这里 n > m n > m n>m。 ConvTranspose2d 的实现方法,与 Assignment 2 | ConvolutionalNetworks 计算 dx 的方法完全相同。实际上,不论在 PyTorch 还是在 TensorFlow 里面,ConvTranspose2d 的实现和计算 dx 的梯度的实现,使用的是同一段代码。在 PyTorch 的文档里明确说明了这一点:
This module can be seen as the gradient of Conv2d with respect to its input.
这里先把 Conv2d 中计算 dx 的方法写一下:
这里写出 dx 和 dw 的闭式解很复杂,而且不容易写出代码,所以这里用一个例子来推出 dx 和 dw 的计算过程,根据此计算过程可以将代码写出。这里 stride = 1,pad = 0,x,w,y为:
x = [ x 11 x 12 x 13 x 21 x 22 x 23 x 31 x 32 x 33 ] , w = [ w 11 w 12 w 21 w 22 ] , y = [ y 11 y 12 y 21 y 22 ] x = \begin{bmatrix} x_{11}& x_{12}& x_{13}\newline x_{21}& x_{22}& x_{23}\newline x_{31}& x_{32}& x_{33} \end{bmatrix}, \quad w = \begin{bmatrix} w_{11}& w_{12}\newline w_{21}& w_{22} \end{bmatrix}, \quad y = \begin{bmatrix} y_{11}& y_{12}\newline y_{21}& y_{22} \end{bmatrix} x=[x11x12x13x21x22x23x31x32x33],w=[w11w12w21w22],y=[y11y12y21y22]
将 y = x × w y = x \times w y=x×w 展开:
y 11 = w 11 x 11 + w 12 x 12 + w 21 x 21 + w 22 x 22 y 12 = w 11 x 12 + w 12 x 13 + w 21 x 22 + w 22 x 23 y 21 = w 11 x 21 + w 12 x 22 + w 21 x 31 + w 22 x 32 y 22 = w 11 x 22 + w 12 x 23 + w 21 x 32 + w 22 x 33 \begin{aligned} y_{11} &= w_{11}x_{11} + w_{12}x_{12} + w_{21}x_{21} + w_{22}x_{22} \newline y_{12} &= w_{11}x_{12} + w_{12}x_{13} + w_{21}x_{22} + w_{22}x_{23} \newline y_{21} &= w_{11}x_{21} + w_{12}x_{22} + w_{21}x_{31} + w_{22}x_{32} \newline y_{22} &= w_{11}x_{22} + w_{12}x_{23} + w_{21}x_{32} + w_{22}x_{33} \newline \end{aligned} y11=w11x11+w12x12+w21x21+w22x22y12=w11x12+w12x13+w21x22+w22x23y21=w11x21+w12x22+w21x31+w22x32y22=w11x22+w12x23+w21x32+w22x33
所以:
d x = [ ∂ L ∂ y ∂ y ∂ x 11 ∂ L ∂ y ∂ y ∂ x 12 ∂ L ∂ y ∂ y ∂ x 13 ∂ L ∂ y ∂ y ∂ x 21 ∂ L ∂ y ∂ y ∂ x 22 ∂ L ∂ y ∂ y ∂ x 23 ∂ L ∂ y ∂ y ∂ x 31 ∂ L ∂ y ∂ y ∂ x 32 ∂ L ∂ y ∂ y ∂ x 33 ] \mathrm{d} x = \begin{bmatrix} \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{11}}& \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{12}}& \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{13}}\newline \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{21}}& \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{22}}& \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{23}}\newline \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{31}}& \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{32}}& \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{33}} \end{bmatrix} dx=[∂y∂L∂x11∂y∂y∂L∂x12∂y∂y∂L∂x13∂y∂y∂L∂x21∂y∂y∂L∂x22∂y∂y∂L∂x23∂y∂y∂L∂x31∂y∂y∂L∂x32∂y∂y∂L∂x33∂y]
与 x 11 x_{11} x11 相关的仅有 y 11 y_{11} y11,所以第一项 ∂ L ∂ y ∂ y ∂ x 11 = ∂ y 11 ⋅ w 11 \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{11}} = \partial y_{11} \cdot w_{11} ∂y∂L∂x11∂y=∂y11⋅w11,与 x 12 x_{12} x12 相关的有两项 y 11 y_{11} y11 和 y 12 y_{12} y12,所以第二项 ∂ L ∂ y ∂ y ∂ x 12 = ∂ y 11 ⋅ w 12 + ∂ y 12 ⋅ w 11 \frac{\partial L}{\partial y}\frac{\partial y}{\partial x_{12}} = \partial y_{11} \cdot w_{12} + \partial y_{12} \cdot w_{11} ∂y∂L∂x12∂y=∂y11⋅w12+∂y12⋅w11,依次类推,将最后结果写成如下形式就一目了然了:
d x = [ ∂ y 11 ⋅ w 11 ∂ y 11 ⋅ w 12 + ∂ y 12 ⋅ w 11 ∂ y 12 ⋅ w 12 ∂ y 11 ⋅ w 21 + ∂ y 11 ⋅ w 22 + ∂ y 12 ⋅ w 21 + ∂ y 12 ⋅ w 22 ∂ y 21 ⋅ w 11 ∂ y 21 ⋅ w 12 + ∂ y 22 ⋅ w 11 ∂ y 22 ⋅ w 12 ∂ y 21 ⋅ w 21 ∂ y 21 ⋅ w 22 + ∂ y 22 ⋅ w 21 ∂ y 22 ⋅ w 22 ] \mathrm{d} x = \begin{bmatrix} \partial y_{11} \cdot w_{11}& \partial y_{11} \cdot w_{12} + & \newline &\partial y_{12} \cdot w_{11} & \partial y_{12} \cdot w_{12} \newline & & \newline \partial y_{11} \cdot w_{21} + & \partial y_{11} \cdot w_{22} + & \newline & \partial y_{12} \cdot w_{21} + & \partial y_{12} \cdot w_{22} \newline \partial y_{21} \cdot w_{11} & \partial y_{21} \cdot w_{12} + & \newline & \partial y_{22} \cdot w_{11} & \partial y_{22} \cdot w_{12} \newline & & \newline \partial y_{21} \cdot w_{21} & \partial y_{21} \cdot w_{22} + & \newline & \partial y_{22} \cdot w_{21} & \partial y_{22} \cdot w_{22} \newline \end{bmatrix} dx=[∂y11⋅w11∂y11⋅w12+∂y12⋅w11∂y12⋅w12∂y11⋅w21+∂y11⋅w22+∂y12⋅w21+∂y12⋅w22∂y21⋅w11∂y21⋅w12+∂y22⋅w11∂y22⋅w12∂y21⋅w21∂y21⋅w22+∂y22⋅w21∂y22⋅w22]
显然,dx的计算方法是在一个形如 x 的矩阵上滑动,先计算
∂ y 11 ⋅ [ w 11 w 12 w 21 w 22 ] \partial y_{11} \cdot \begin{bmatrix} w_{11}& w_{12}\newline w_{21}& w_{22} \end{bmatrix} ∂y11⋅[w11w12w21w22]
并将结果放在 dx 的第一个形如 w 的块上,然后计算
∂ y 12 ⋅ [ w 11 w 12 w 21 w 22 ] \partial y_{12} \cdot \begin{bmatrix} w_{11}& w_{12}\newline w_{21}& w_{22} \end{bmatrix} ∂y12⋅[w11w12w21w22]
滑动 stride,并将结果放在 dx 的第二个形如 w 的块上,依次类推。
这里需要注意的是:
以上仅是针对 x 的最后两个维度的计算,前两个维度加循环即可
for i in range(N):
for oc in range(K):
for ww in range(out_w):
for hh in range(out_h):
dpad_x[i, :, (s*hh):(s*hh+f_h), (s*ww):(s*ww+f_w)] += dout[i, oc, hh, ww] * w[oc, ...]
dx = dpad_x[:, :, p:(in_h+p), p:(in_w+p)]
实际上,我用 PyTorch 的 torch.nn.functional.conv_transpose2d 实现了一下计算 dx 的梯度,得到的结果是一样的:
dout_tensor = torch.from_numpy(dout)
w_tensor = torch.from_numpy(w)
dx = torch.nn.functional.conv_transpose2d(dout_tensor, w_tensor, bias=None, stride=s, padding=p)
dx = dx.numpy()
ConvTranspose2d 中的参数设置,特别是 padding 和 函数的输出形状破费一些思量。我的建议是首先找出其对偶的 Conv2d,此 Conv2d 中的参数就是 ConvTranspose2d 中的参数,除了输入输出形状互换以外。通过下面这个例子详细说一下。
作业中在实现 CNN GAN 时,给出的 Generator 实现为:
Reshape into Image Tensor of shape 7, 7, 128
Conv2D^T (Transpose): 64 filters of 4x4, stride 2, ‘same’ padding
ReLU
BatchNorm
Conv2D^T (Transpose): 1 filter of 4x4, stride 2, ‘same’ padding
第一个 ConvTranspose2d 的输入形状显然是 (7, 7),有 128 个 channel,ConvTranspose2d 的 filter 形状是 (4, 4),stride = 2,使用 ‘SAME’ padding。那么,ConvTranspose2d 的输出形状到底是多少呢?这里的 ‘SAME’ padding 到底是 pad 多少 0 呢?
首先,‘SAME’ padding 的说法显然是来自 TensorFlow 而非 PyTorch,所以,我们看 TensorFlow 的 tf.nn.conv2d_transpose 手册。里面在讲 padding 的时候,直接链接到了 tf.nn.convolution 中关于 padding 的说明:
If padding == “SAME”: output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides[i])
If padding == “VALID”: output_spatial_shape[i] = ceil((input_spatial_shape[i] - (spatial_filter_shape[i]-1) * dilation_rate[i]) / strides[i]).
注意,这里的 input 和 output 是针对 Conv2d 而言的,它正好和 ConvTranspose2d 是相反的!所以,如果使用 ‘SAME’ 的话,ConvTranspose2d 的 output 应该是 output = input * stride。因为 Conv2d 在计算时有一个 ceil 在算式里,所以,ConvTranspose2d 的 output 大小不是唯一的。以作业为例,ConvTranspose2d 的输入形如 (7, 7),stride = 2,那么其对偶的 Conv2d 的输入大小可以是 7 * 2 = 14,也可以是 7 * 2 + 1 = 15。因为 Conv2d 计算输出大小的公式是 ceil ( input / stride )。这也是为什么在 PyTorch 版本的 ConvTranspose2d 中还要额外给一个 output_padding 的参数,而且还有一个 Note 说:
However, when :attr
stride
>1, Conv2d maps multiple input shapes to the same output shape. output_padding is provided to resolve this ambiguity by effectively increasing the calculated output shape on one side. Note that output_padding is only used to find output shape, but does not actually add zero-padding to output.
到这里,我们得到了对偶的 Conv2d 的输入是 (14, 14) 或者 (15, 15),输出是 (7, 7),kernel 是 (4, 4),stride = 2,那么 padding 就可以计算出来了:
这两个结果没有对错之分,只不过我们取一个偶数的值是比较好的,所以这里取 14。到此为止,Conv2d 的所有参数都已经确定,即
torch.nn.Conv2d(128, 64, (4, 4), stride=2, padding=1)
它的输入是形如 (batch_size, 128, 14, 14),输出是形如 (batch_size, 64, 7, 7) 的。
因此,其对偶的 ConvTranspose2d 的参数与其完全相同,即为:
torch.nn.ConvTranspose2d(128, 64, (4, 4), stride=2, padding=1)
它的输入是形如 (batch_size, 128, 7, 7),输出是形如 (batch_size, 64, 14, 14) 的。从而将一张 7 * 7 的图片 upsampling 到了 14 * 14。
第二个 ConvTranspose2d 层用同样方法计算
Conv2D^T (Transpose): 1 filter of 4x4, stride 2, ‘same’ padding
Conv2d 的输出为 14 * 14 的图片,采用 stride 2,‘same’ padding,所以 Conv2d 的输入为28或者29,这里取28。 需要 padding = 1,所以
torch.nn.Conv2d(64, 1, (4, 4), stride=2, padding=1)
那么其对偶的 ConvTranspose2d 为:
torch.nn.ConvTranspose2d(64, 1, (4, 4), stride=2, padding=1)