首先看官方文档(ConvTranspose2d — PyTorch 1.13 documentation)给出的调用参数
torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1,
padding=0, output_padding=0, groups=1, bias=True, dilation=1,
padding_mode='zeros', device=None, dtype=None)
就几个关键参数说一下我个人理解
实现过程
下面说说我个人理解的实现过程,假设输入tensor大小3*3,卷积核大小为2*2,stride=2,padding=1,output_padding=0,dilation=2
一、首先定义几个参数以及计算方式
实际填充数 padding_ = dilation * (kernel-1) - padding
输入tensor元素间填充(默认填充0)fill = stride - 1
卷积核元素间填充(默认填充0)kernel_fill = dilation - 1
则本示例中padding_ = 2*(2-1)-1=1, fill=2-1=1, kernel_fill=2-1=1
二、对输入tensor以及卷积核kernel进行填充
输入tensor以及卷积核的变换如图(其中蓝色为输入元素,红色为padding_,绿色为元素间填充fill,紫色为卷积核元素kernel,橙色为卷积核元素间填充kernel_fill)
注意,ConvTranspose中权重矩阵需要旋转180度(原因在转置卷积正向推导过程中有,本文仅做实现及应用理解)
填充后得到input_ 和 kernel_
三、卷积
对填充后的输入进行普通卷积,输入input_大小7*7,kernel_大小3*3,其中padding恒为0,stride恒为1,如以下函数表示
output_ = F.conv2d(input_,kernel_,padding=0,stride=1)
获得的output_尺寸大小为5*5,与官方文档中给出的计算结果一致
官方文档计算输出尺寸的公式:
本示例中Hout=(3-1)*2 - 2*1 + 2*(2-1) + 0 + 1 = 5
四、带上数值检验
import torch.nn.functional as F
import torch
input = torch.tensor([[[[0.11, 0.12, 0.13],
[0.21, 0.22, 0.23],
[0.24, 0.25, 0.26]],
[[ 0.111, 0.121, 0.131],
[ 0.211, 0.221, 0.231],
[0.241, 0.251, 0.261]],
[[0.112, 0.122, 0.132],
[0.212, 0.222, 0.232],
[0.242, 0.252, 0.262]]]])
weight = torch.tensor([[[[1., 0.],
[0., 0.]]],
[[[0., 1.],
[0., 0.]]],
[[[0., 0.],
[1., 0.]]]])
# after transpose
# weight_ = torch.tensor([[[[0., 0.],
# [0., 1.]]],
# [[[0., 0.],
# [1., 0.]]],
# [[[0., 1.],
# [0., 0.]]]])
output = F.conv_transpose2d(input,weight,bias=None,stride=2,dilation=2,padding=1)
output
tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.5530, 0.0000, 0.5830, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.7130, 0.0000, 0.7430, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]]])
对于output[0][1][1]=0.553 = input[0][1][1] + input[1][0][1] + input[2][1][0] = 0.22+0.211+0.122。经验证本文理解的实现过程有效。
这些个人理解其实是我需要以C语言复现反卷积才想的,其中还有group参数我没加入考虑,不过足以应对多数情况了。如有我考虑不周或有问题的地方,欢迎大家交流。
参考
转置卷积(ConvTranspose2d)的具体实现 - 知乎
ConvTranspose2d — PyTorch 1.13 documentation