在两年前的这篇文章Deformable Convolution(可变形卷积)代码解析(有错误,修改中)中,当时对可变形卷积进行了代码解读,后来被网友指出其中的解释是错的,里面引用的keras版本的代码实现也是错的,后来查看了一些其他人的文章,其中正确和错误的解释都掺杂在一起,不少人和我之前的错误理解是一样的,一直不敢确定所以文章一直没有修正,现在应该是搞懂了,在这里对当时的错误解析进行纠正。
这里讲解的代码实现来自于https://github.com/4uiiurz1/pytorch-deform-conv-v2,其中当modulation=True时,就是v2,否则就是v1,这里只考虑v1。
可变形卷积的流程就是相比于普通的卷积,多学了一层偏移offset,偏移是通过单独的一层卷积学习到的,然后将学到的偏移offset与原始的输入特征进行相加,因为偏移有可能是小数,因此通过双线性插值得到偏移后的特征值,然后再经过原本的那层普通卷积就得到了最终的输出。
之前的错误理解:对于输入shape为(b, c, w, h)的feature map,经过一层卷积后学习到的偏移feature map的shape为(b, 2, w, h),相当于输入特征图上每个像素位置都学习x,y两个方向的偏移,然后将这个偏移特征图与原始输入进行相加,(注意并不是与原始输入特征图直接相加,而是与原始特征图上每个像素的坐标进行相加),这样就得到了原始每个位置偏移后的坐标,然后进行bilinear插值得到偏移后坐标处的值,这里还要注意的是通道共享偏移,因此插值后的特征图和原始输入的shape一样都是(b, c, w, h),然后再经过原本的普通卷积层就得到了可变形卷积层的最终输出。
上面错误的地方在于:偏移特征图的shape应该是(b, N*2, w, h),其中N是原来的普通卷积的大小,比如如果普通卷积是一个3x3的卷积,这里N=3x3=9。这是因为可变形卷积中变形的是卷积核的形状,普通的3x3卷积如下图(a)所示,是一个规则的3x3方格,且采样点之间的间隔为1。(b),(c),(d)是变形后的3x3卷积,9个采样点由绿点变成了蓝点。对于一个(b, c_in, w, h)的输入特征图,假设原始卷积为3x3,padding=1,stride=1,输出特征图shape为(b, c_out, w, h),即特征图的宽高不变。这里卷积窗口沿输入spatial维度进行滑动时总共滑动了w*h个位置,当卷积为可变形卷积时,卷积核每滑动到一个位置处,原始的3x3方格采样都变成了偏移后的位置,因为卷积核有9个采样位置,每个又有x,y两个方向的偏移,因此偏移输出特征图的shape应该为(b, 18, w, h)。
总结如下
完整实现如下所示,假设输入x.shape=(1, 64, 5, 5)。
import torch
from torch import nn
class DeformConv2d(nn.Module):
def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):
"""
Args:
modulation (bool, optional): If True, Modulated Defomable Convolution (Deformable ConvNets v2).
"""
super(DeformConv2d, self).__init__()
self.kernel_size = kernel_size
self.padding = padding
self.stride = stride
self.zero_padding = nn.ZeroPad2d(padding)
self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)
self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
nn.init.constant_(self.p_conv.weight, 0)
self.p_conv.register_backward_hook(self._set_lr)
self.modulation = modulation
if modulation:
self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
nn.init.constant_(self.m_conv.weight, 0)
self.m_conv.register_backward_hook(self._set_lr)
@staticmethod
def _set_lr(module, grad_input, grad_output):
grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))
def forward(self, x): # (1,64,5,5)
offset = self.p_conv(x) # (1, 18, 5, 5)
if self.modulation:
m = torch.sigmoid(self.m_conv(x))
dtype = offset.data.type()
ks = self.kernel_size # 3
N = offset.size(1) // 2 # 9
if self.padding:
x = self.zero_padding(x) # (1, 64, 5, 5) -> (1, 64, 7, 7)
# (b, 2N, h, w)
p = self._get_p(offset, dtype)
# print(p.shape)
# print(p)
# (b, h, w, 2N)
p = p.contiguous().permute(0, 2, 3, 1)
q_lt = p.detach().floor()
q_rb = q_lt + 1
q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long() # (1,5,5,18)
q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long() # (1,5,5,18)
q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1) # (1,5,5,18)
q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1) # (1,5,5,18)
# clip p
p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1) # (1,5,5,18)
# bilinear kernel (b, h, w, N)
g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
# (b, c, h, w, N)
x_q_lt = self._get_x_q(x, q_lt, N)
x_q_rb = self._get_x_q(x, q_rb, N)
x_q_lb = self._get_x_q(x, q_lb, N)
x_q_rt = self._get_x_q(x, q_rt, N)
# (b, c, h, w, N)
x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
g_rb.unsqueeze(dim=1) * x_q_rb + \
g_lb.unsqueeze(dim=1) * x_q_lb + \
g_rt.unsqueeze(dim=1) * x_q_rt
# modulation
if self.modulation:
m = m.contiguous().permute(0, 2, 3, 1)
m = m.unsqueeze(dim=1)
m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
x_offset *= m
x_offset = self._reshape_x_offset(x_offset, ks)
out = self.conv(x_offset)
return out
def _get_p_n(self, N, dtype):
p_n_x, p_n_y = torch.meshgrid(
torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
# (2N, 1)
p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
p_n = p_n.view(1, 2*N, 1, 1).type(dtype)
return p_n
def _get_p_0(self, h, w, N, dtype):
p_0_x, p_0_y = torch.meshgrid(
torch.arange(1, h*self.stride+1, self.stride),
torch.arange(1, w*self.stride+1, self.stride))
p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
return p_0
def _get_p(self, offset, dtype):
N, h, w = offset.size(1)//2, offset.size(2), offset.size(3) # 9,5,5
# (1, 2N, 1, 1)
p_n = self._get_p_n(N, dtype) # 3x3卷积内9个点相对中心点(0,0)的偏移坐标
# tensor([[[[-1.]],
#
# [[-1.]],
#
# [[-1.]],
#
# [[ 0.]],
#
# [[ 0.]],
#
# [[ 0.]],
#
# [[ 1.]],
#
# [[ 1.]],
#
# [[ 1.]],
#
# [[-1.]],
#
# [[ 0.]],
#
# [[ 1.]],
#
# [[-1.]],
#
# [[ 0.]],
#
# [[ 1.]],
#
# [[-1.]],
#
# [[ 0.]],
#
# [[ 1.]]]])
# (1, 2N, h, w)
p_0 = self._get_p_0(h, w, N, dtype) # 输入特征图上的每个像素点的原始坐标
# tensor([[[[1., 1., 1., 1., 1.],
# [2., 2., 2., 2., 2.],
# [3., 3., 3., 3., 3.],
# [4., 4., 4., 4., 4.],
# [5., 5., 5., 5., 5.]],
#
# [[1., 1., 1., 1., 1.],
# [2., 2., 2., 2., 2.],
# [3., 3., 3., 3., 3.],
# [4., 4., 4., 4., 4.],
# [5., 5., 5., 5., 5.]],
#
# [[1., 1., 1., 1., 1.],
# [2., 2., 2., 2., 2.],
# [3., 3., 3., 3., 3.],
# [4., 4., 4., 4., 4.],
# [5., 5., 5., 5., 5.]],
#
# [[1., 1., 1., 1., 1.],
# [2., 2., 2., 2., 2.],
# [3., 3., 3., 3., 3.],
# [4., 4., 4., 4., 4.],
# [5., 5., 5., 5., 5.]],
#
# [[1., 1., 1., 1., 1.],
# [2., 2., 2., 2., 2.],
# [3., 3., 3., 3., 3.],
# [4., 4., 4., 4., 4.],
# [5., 5., 5., 5., 5.]],
#
# [[1., 1., 1., 1., 1.],
# [2., 2., 2., 2., 2.],
# [3., 3., 3., 3., 3.],
# [4., 4., 4., 4., 4.],
# [5., 5., 5., 5., 5.]],
#
# [[1., 1., 1., 1., 1.],
# [2., 2., 2., 2., 2.],
# [3., 3., 3., 3., 3.],
# [4., 4., 4., 4., 4.],
# [5., 5., 5., 5., 5.]],
#
# [[1., 1., 1., 1., 1.],
# [2., 2., 2., 2., 2.],
# [3., 3., 3., 3., 3.],
# [4., 4., 4., 4., 4.],
# [5., 5., 5., 5., 5.]],
#
# [[1., 1., 1., 1., 1.],
# [2., 2., 2., 2., 2.],
# [3., 3., 3., 3., 3.],
# [4., 4., 4., 4., 4.],
# [5., 5., 5., 5., 5.]],
#
# [[1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.]],
#
# [[1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.]],
#
# [[1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.]],
#
# [[1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.]],
#
# [[1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.]],
#
# [[1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.]],
#
# [[1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.]],
#
# [[1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.]],
#
# [[1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.],
# [1., 2., 3., 4., 5.]]]])
p = p_0 + p_n + offset
# p = p_0 + p_n
return p
def _get_x_q(self, x, q, N):
b, h, w, _ = q.size()
padded_w = x.size(3)
c = x.size(1)
# (b, c, h*w)
x = x.contiguous().view(b, c, -1)
# (b, h, w, N)
index = q[..., :N]*padded_w + q[..., N:] # offset_x*w + offset_y
# (b, c, h*w*N)
index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
return x_offset
@staticmethod
def _reshape_x_offset(x_offset, ks):
b, c, h, w, N = x_offset.size()
x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)
return x_offset
if __name__ == '__main__':
deformconv2d = DeformConv2d(
64, 128, kernel_size=3, padding=1, stride=1, bias=None, modulation=False
)
_input = torch.ones((1, 64, 5, 5))
result = deformconv2d(_input)
print(result.shape)
print(result)
forward函数中首先通过self.p_conv得到偏移offset,输出shape为(1, 18, 5, 5)。
self.modulation=True时,是v2新增的东西,这里不关心。
p=self.get_p是用来得到偏移后的采样坐标位置的,self.get_p中又包括p_n=self._get_p_n和p_0=self._get_p_0,前者是用来得到原始3x3卷积核内每个位置的坐标的,其中中心当做原点坐标为(0, 0),左上角为(-1, -1),右下角为(1, 1),输出shape为(1, 18, 1, 1)。后者是用来得到卷积核在输入特征图上滑动时,卷积核的中心点相对于输入特征图原点的坐标,输入特征图的原点就是左上角,注意因为原始输入加上了padding=1,因此输入的第一行的y坐标为1,第一列的x坐标为1。最终p=p_0+p_n+offset就得到了卷积核在每个滑动位置处的偏移后的采样坐标,这里的采样坐标是相对于输入特征图左上角原点的,shape=(1, 18, 5, 5)
这里偏移后的坐标是小数,还需要通过双线性插值得到坐标处的值。关于双线性插值可参考Deformable Convolution(可变形卷积)代码解析(有错误,修改中)中的介绍。其中p向下取整就得到了左上角处的坐标q_lt,左上角坐标+1得到右下角坐标q_rb,左上角的x坐标和右下角的y坐标拼接得到左下角坐标q_lb,右下角的x坐标和左上角的y坐标拼接得到右上角坐标q_rt。如下图,x(p)的值由最近的四个整数坐标处的值x(q1)、x(q2)、x(q3)、x(q4)以及距离u、v通过双线性插值计算得到。
t1 = (1-u)*x(q1) + u*x(q2)
t2 = (1-u)*x(q3) + u*x(q4)
x(p) = (1-v)*t1 + v*t2
# x(p) = (1-v)*(1-u)*x(q1) + (1-v)*u*x(q2) + v*(1-u)*x(q3) + v*u*x(q4)
四个点的权重g_lt、g_rb、g_lb、g_rt分别对应上面代码中的(1-v)*(1-u)、v*u、v*(1-u)、(1-v)*u。然后通过self._get_x_q()从输入特征图x中得到四个点处的值,分别对应上图中的x(q1)、x(q2)、x(q3)、x(q4)。然后相乘并求和就得到了偏移后坐标处的值x_offset即上面的x(p)。
最后,x_offset的shape=(1, 64, 5, 5, 9),这里的9即对应原始的3x3卷积滑动到5x5中的每个像素位置处,对应的输入特征图上的9个采样值,然后通过self._reshape_x_offset()将x_offset reshape成(1, 64, 5*3, 5*3),然后再经过原始的那层普通卷积self.conv()就得到了最终输出,需要注意原始卷积的stride由1改成了3。