ConvTranspose2d 的简单例子理解

文章目录

    • 参考
    • 基础概念
      • output_padding
    • 简单例子: stride=2
        • step1
        • step2
        • step3

参考

  • 逆卷积的详细解释ConvTranspose2d(fractionally-strided convolutions)
  • nn.ConvTranspose2d的参数output_padding的作用
  • torch.nn.ConvTranspose2d Explained

基础概念

逆卷积,也叫反卷积或者转置卷积,作用是对图像进行上采样。
参考链接中的文章对形状变化的公式做了较为详细的描述,这里简单使用一个例子来演示数据变换过程。

假如输入数据(也叫原始数据)形状为 H x W, H = W.
大概的流程可以描述为

  • 当stride >1, 对原始数据进行插值变换,也就是每相邻数据间插入 (s-1) 列/行数据,变换后数据形状为
    H n e w = W n e w = H + ( s − 1 ) ∗ ( H − 1 ) H_{new} = W_{new} = H + (s-1)*(H-1) Hnew=Wnew=H+(s1)(H1)
  • 对变换后的数据进行padding,
    p a d d i n g n e w = k e r n e l _ s i z e − p a d d i n g − 1 padding_{new} = kernel\_size - padding - 1 paddingnew=kernel_sizepadding1
    经过这一步之后,数据形状为
    H n e w + 2 ∗ p a d d i n g n e w H_{new} + 2* padding_{new} Hnew+2paddingnew
  • 在这两步变换后进行如下正常卷积运算
    kernel_size = kernel_size
    padding = 0
    stride = 1

output_padding

正常卷积在运算时,会由于一些取整操作导致输入图像不同,但是最后生成的图像相同。
而反卷积在进行逆运算时,也会出现类似的情况。理想情况下希望 输出的图像尺寸/输入图像尺寸 = stride。
要完全达成这个目的,就通过在最后的形状上的一边添加 output_padding 完成。

简单例子: stride=2

h = w = 2
stride = 2
p = 0
kernel_size = 3
step1

原始数据:
ConvTranspose2d 的简单例子理解_第1张图片

step2

内部变换:
stride >1,
当卷积时设置的stride>1时,将对输入的特征图y进行插值操作(interpolation)。

即需要在输入的特征图y的每个相邻值之间插入(stride-1)行和列0,因为特征图中能够插入的相邻位置有(height-1)个位置,所以此时得到的特征图的大小由 H o u t × H o u t Hout \times Hout Hout×Hout(Hout即height) 变为新的 H o u t n e w × H o u t n e w Hout_new\times Hout_new Houtnew×Houtnew,即 [ H o u t + ( s t r i d e − 1 ) × ( H o u t − 1 ) ] × [ H o u t + ( s t r i d e − 1 ) × ( H o u t − 1 ) ] [Hout + (stride-1) \times (Hout-1)] \times [Hout + (stride-1) \times (Hout-1)] [Hout+(stride1)×(Hout1)]×[Hout+(stride1)×(Hout1)]
ConvTranspose2d 的简单例子理解_第2张图片

step3

外部变换:
为了实现由 H o u t × H o u t Hout \times Hout Hout×Hout 大小的y逆卷积得到 H i n × H i n Hin \times Hin Hin×Hin大小的x,还需要设置padding_new的值为(kernel_size - padding - 1),这里的padding是卷积操作时设置的padding值

ConvTranspose2d 的简单例子理解_第3张图片
然后在这个变换好的图上进行 kernel_size=3, padding =0, stride =1 的正常卷积可以得到最终结果
按照卷积的变换公式计算得到

(w + 2p - kernel_size)/s +1 = (7 - 3) +1 = 5.

也就是得到 5x5 的数据。

ConvTranspose2d 的简单例子理解_第4张图片
ConvTranspose2d 的简单例子理解_第5张图片
ConvTranspose2d 的简单例子理解_第6张图片

你可能感兴趣的:(python,深度学习)