DCN v1 可变形卷积v1解析(修正篇)

在两年前的这篇文章Deformable Convolution(可变形卷积)代码解析(有错误,修改中)中,当时对可变形卷积进行了代码解读,后来被网友指出其中的解释是错的,里面引用的keras版本的代码实现也是错的,后来查看了一些其他人的文章,其中正确和错误的解释都掺杂在一起,不少人和我之前的错误理解是一样的,一直不敢确定所以文章一直没有修正,现在应该是搞懂了,在这里对当时的错误解析进行纠正。

这里讲解的代码实现来自于https://github.com/4uiiurz1/pytorch-deform-conv-v2,其中当modulation=True时,就是v2,否则就是v1,这里只考虑v1。

可变形卷积的流程就是相比于普通的卷积,多学了一层偏移offset,偏移是通过单独的一层卷积学习到的,然后将学到的偏移offset与原始的输入特征进行相加,因为偏移有可能是小数,因此通过双线性插值得到偏移后的特征值,然后再经过原本的那层普通卷积就得到了最终的输出。

之前的错误理解:对于输入shape为(b, c, w, h)的feature map,经过一层卷积后学习到的偏移feature map的shape为(b, 2, w, h),相当于输入特征图上每个像素位置都学习x,y两个方向的偏移,然后将这个偏移特征图与原始输入进行相加,(注意并不是与原始输入特征图直接相加,而是与原始特征图上每个像素的坐标进行相加),这样就得到了原始每个位置偏移后的坐标,然后进行bilinear插值得到偏移后坐标处的值,这里还要注意的是通道共享偏移,因此插值后的特征图和原始输入的shape一样都是(b, c, w, h),然后再经过原本的普通卷积层就得到了可变形卷积层的最终输出。

上面错误的地方在于:偏移特征图的shape应该是(b, N*2, w, h),其中N是原来的普通卷积的大小,比如如果普通卷积是一个3x3的卷积,这里N=3x3=9。这是因为可变形卷积中变形的是卷积核的形状,普通的3x3卷积如下图(a)所示,是一个规则的3x3方格,且采样点之间的间隔为1。(b),(c),(d)是变形后的3x3卷积,9个采样点由绿点变成了蓝点。对于一个(b, c_in, w, h)的输入特征图,假设原始卷积为3x3,padding=1,stride=1,输出特征图shape为(b, c_out, w, h),即特征图的宽高不变。这里卷积窗口沿输入spatial维度进行滑动时总共滑动了w*h个位置,当卷积为可变形卷积时,卷积核每滑动到一个位置处,原始的3x3方格采样都变成了偏移后的位置,因为卷积核有9个采样位置,每个又有x,y两个方向的偏移,因此偏移输出特征图的shape应该为(b, 18, w, h)。

DCN v1 可变形卷积v1解析(修正篇)_第1张图片

总结如下

  1. 可变形卷积学习的是卷积核的变形采样能力。比如假设目标物体是一个菱形的形状,此时卷积核的形状是一个菱形可能会比原始的方格形状能更好的提取目标的特征。
  2. 因为卷积核是在输入特征图上进行滑动采样,然后与窗口内对应的输入特征进行加权求和得到输出。因此可以将卷积核形状的偏移等价为输入特征图上采样位置的偏移,实际上也是这么做的。

代码解析

完整实现如下所示,假设输入x.shape=(1, 64, 5, 5)。

import torch
from torch import nn


class DeformConv2d(nn.Module):
    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):
        """
        Args:
            modulation (bool, optional): If True, Modulated Defomable Convolution (Deformable ConvNets v2).
        """
        super(DeformConv2d, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.ZeroPad2d(padding)
        self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)

        self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
        nn.init.constant_(self.p_conv.weight, 0)
        self.p_conv.register_backward_hook(self._set_lr)

        self.modulation = modulation
        if modulation:
            self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
            nn.init.constant_(self.m_conv.weight, 0)
            self.m_conv.register_backward_hook(self._set_lr)

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
        grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))

    def forward(self, x):  # (1,64,5,5)
        offset = self.p_conv(x)  # (1, 18, 5, 5)
        if self.modulation:
            m = torch.sigmoid(self.m_conv(x))

        dtype = offset.data.type()
        ks = self.kernel_size  # 3
        N = offset.size(1) // 2  # 9

        if self.padding:
            x = self.zero_padding(x)  # (1, 64, 5, 5) -> (1, 64, 7, 7)

            # (b, 2N, h, w)
        p = self._get_p(offset, dtype)
        # print(p.shape)
        # print(p)

        # (b, h, w, 2N)
        p = p.contiguous().permute(0, 2, 3, 1)

        q_lt = p.detach().floor()
        q_rb = q_lt + 1

        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()  # (1,5,5,18)
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()  # (1,5,5,18)
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)  # (1,5,5,18)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)  # (1,5,5,18)

        # clip p
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)  # (1,5,5,18)

        # bilinear kernel (b, h, w, N)
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

        # (b, c, h, w, N)
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        # (b, c, h, w, N)
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
        g_rb.unsqueeze(dim=1) * x_q_rb + \
        g_lb.unsqueeze(dim=1) * x_q_lb + \
        g_rt.unsqueeze(dim=1) * x_q_rt

        # modulation
        if self.modulation:
        m = m.contiguous().permute(0, 2, 3, 1)
        m = m.unsqueeze(dim=1)
        m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
        x_offset *= m

        x_offset = self._reshape_x_offset(x_offset, ks)
        out = self.conv(x_offset)

        return out

        def _get_p_n(self, N, dtype):
        p_n_x, p_n_y = torch.meshgrid(
        torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
        torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
        # (2N, 1)
        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
        p_n = p_n.view(1, 2*N, 1, 1).type(dtype)

        return p_n

        def _get_p_0(self, h, w, N, dtype):
        p_0_x, p_0_y = torch.meshgrid(
        torch.arange(1, h*self.stride+1, self.stride),
        torch.arange(1, w*self.stride+1, self.stride))
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)

        return p_0

        def _get_p(self, offset, dtype):
        N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)  # 9,5,5

        # (1, 2N, 1, 1)
        p_n = self._get_p_n(N, dtype)  # 3x3卷积内9个点相对中心点(0,0)的偏移坐标
        # tensor([[[[-1.]],
        #
        #          [[-1.]],
        #
        #          [[-1.]],
        #
        #          [[ 0.]],
        #
        #          [[ 0.]],
        #
        #          [[ 0.]],
        #
        #          [[ 1.]],
        #
        #          [[ 1.]],
        #
        #          [[ 1.]],
        #
        #          [[-1.]],
        #
        #          [[ 0.]],
        #
        #          [[ 1.]],
        #
        #          [[-1.]],
        #
        #          [[ 0.]],
        #
        #          [[ 1.]],
        #
        #          [[-1.]],
        #
        #          [[ 0.]],
        #
        #          [[ 1.]]]])
        # (1, 2N, h, w)
        p_0 = self._get_p_0(h, w, N, dtype)  # 输入特征图上的每个像素点的原始坐标
        # tensor([[[[1., 1., 1., 1., 1.],
        #           [2., 2., 2., 2., 2.],
        #           [3., 3., 3., 3., 3.],
        #           [4., 4., 4., 4., 4.],
        #           [5., 5., 5., 5., 5.]],
        #
        #          [[1., 1., 1., 1., 1.],
        #           [2., 2., 2., 2., 2.],
        #           [3., 3., 3., 3., 3.],
        #           [4., 4., 4., 4., 4.],
        #           [5., 5., 5., 5., 5.]],
        #
        #          [[1., 1., 1., 1., 1.],
        #           [2., 2., 2., 2., 2.],
        #           [3., 3., 3., 3., 3.],
        #           [4., 4., 4., 4., 4.],
        #           [5., 5., 5., 5., 5.]],
        #
        #          [[1., 1., 1., 1., 1.],
        #           [2., 2., 2., 2., 2.],
        #           [3., 3., 3., 3., 3.],
        #           [4., 4., 4., 4., 4.],
        #           [5., 5., 5., 5., 5.]],
        #
        #          [[1., 1., 1., 1., 1.],
        #           [2., 2., 2., 2., 2.],
        #           [3., 3., 3., 3., 3.],
        #           [4., 4., 4., 4., 4.],
        #           [5., 5., 5., 5., 5.]],
        #
        #          [[1., 1., 1., 1., 1.],
        #           [2., 2., 2., 2., 2.],
        #           [3., 3., 3., 3., 3.],
        #           [4., 4., 4., 4., 4.],
        #           [5., 5., 5., 5., 5.]],
        #
        #          [[1., 1., 1., 1., 1.],
        #           [2., 2., 2., 2., 2.],
        #           [3., 3., 3., 3., 3.],
        #           [4., 4., 4., 4., 4.],
        #           [5., 5., 5., 5., 5.]],
        #
        #          [[1., 1., 1., 1., 1.],
        #           [2., 2., 2., 2., 2.],
        #           [3., 3., 3., 3., 3.],
        #           [4., 4., 4., 4., 4.],
        #           [5., 5., 5., 5., 5.]],
        #
        #          [[1., 1., 1., 1., 1.],
        #           [2., 2., 2., 2., 2.],
        #           [3., 3., 3., 3., 3.],
        #           [4., 4., 4., 4., 4.],
        #           [5., 5., 5., 5., 5.]],
        #
        #          [[1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.]],
        #
        #          [[1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.]],
        #
        #          [[1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.]],
        #
        #          [[1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.]],
        #
        #          [[1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.]],
        #
        #          [[1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.]],
        #
        #          [[1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.]],
        #
        #          [[1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.]],
        #
        #          [[1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.],
        #           [1., 2., 3., 4., 5.]]]])

        p = p_0 + p_n + offset
        # p = p_0 + p_n
        return p

        def _get_x_q(self, x, q, N):
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        # (b, c, h*w)
        x = x.contiguous().view(b, c, -1)

        # (b, h, w, N)
        index = q[..., :N]*padded_w + q[..., N:]  # offset_x*w + offset_y
        # (b, c, h*w*N)
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)

        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

        return x_offset

        @staticmethod
        def _reshape_x_offset(x_offset, ks):
        b, c, h, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
        x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)

        return x_offset


        if __name__ == '__main__':
        deformconv2d = DeformConv2d(
        64, 128, kernel_size=3, padding=1, stride=1, bias=None, modulation=False
        )
        _input = torch.ones((1, 64, 5, 5))
        result = deformconv2d(_input)
        print(result.shape)
        print(result)

forward函数中首先通过self.p_conv得到偏移offset,输出shape为(1, 18, 5, 5)。

self.modulation=True时,是v2新增的东西,这里不关心。

p=self.get_p是用来得到偏移后的采样坐标位置的,self.get_p中又包括p_n=self._get_p_n和p_0=self._get_p_0,前者是用来得到原始3x3卷积核内每个位置的坐标的,其中中心当做原点坐标为(0, 0),左上角为(-1, -1),右下角为(1, 1),输出shape为(1, 18, 1, 1)。后者是用来得到卷积核在输入特征图上滑动时,卷积核的中心点相对于输入特征图原点的坐标,输入特征图的原点就是左上角,注意因为原始输入加上了padding=1,因此输入的第一行的y坐标为1,第一列的x坐标为1。最终p=p_0+p_n+offset就得到了卷积核在每个滑动位置处的偏移后的采样坐标,这里的采样坐标是相对于输入特征图左上角原点的,shape=(1, 18, 5, 5)

这里偏移后的坐标是小数,还需要通过双线性插值得到坐标处的值。关于双线性插值可参考Deformable Convolution(可变形卷积)代码解析(有错误,修改中)中的介绍。其中p向下取整就得到了左上角处的坐标q_lt,左上角坐标+1得到右下角坐标q_rb,左上角的x坐标和右下角的y坐标拼接得到左下角坐标q_lb,右下角的x坐标和左上角的y坐标拼接得到右上角坐标q_rt。如下图,x(p)的值由最近的四个整数坐标处的值x(q1)、x(q2)、x(q3)、x(q4)以及距离u、v通过双线性插值计算得到。

DCN v1 可变形卷积v1解析(修正篇)_第2张图片

t1 = (1-u)*x(q1) + u*x(q2)
t2 = (1-u)*x(q3) + u*x(q4)
x(p) = (1-v)*t1 + v*t2
# x(p) = (1-v)*(1-u)*x(q1) + (1-v)*u*x(q2) + v*(1-u)*x(q3) + v*u*x(q4)

四个点的权重g_lt、g_rb、g_lb、g_rt分别对应上面代码中的(1-v)*(1-u)、v*u、v*(1-u)、(1-v)*u。然后通过self._get_x_q()从输入特征图x中得到四个点处的值,分别对应上图中的x(q1)、x(q2)、x(q3)、x(q4)。然后相乘并求和就得到了偏移后坐标处的值x_offset即上面的x(p)。

最后,x_offset的shape=(1, 64, 5, 5, 9),这里的9即对应原始的3x3卷积滑动到5x5中的每个像素位置处,对应的输入特征图上的9个采样值,然后通过self._reshape_x_offset()将x_offset reshape成(1, 64, 5*3, 5*3),然后再经过原始的那层普通卷积self.conv()就得到了最终输出,需要注意原始卷积的stride由1改成了3。 

你可能感兴趣的:(目标检测,深度学习,cnn,神经网络,目标检测,计算机视觉)