Pytorch 实现自定义卷积:以 2.5 维卷积(2.5D Convolution)为例

        在用 Pytorch 实现各种卷积神经网络的时候,一般用到的卷积层都是系统自带的 2 维卷积 torch.nn.Conv2d,或者较少用到的 1 维卷积 torch.nn.Conv1d3 维卷积 torch.nn.Conv3d,这些卷积提供了足够的参数,使得实现带洞卷积(Atrous Convolution)深度可分离卷积(Depthwise Seperable Convolution)等特殊卷积都易如反掌。但有时候,为了某些特殊的需求,不能直接使用经典的卷积层,而是要自定义的实现某种新的卷积运算,比如可形变卷积(Deformable Convolution),因此学会从底层实现自定义卷积层是必要且必须的。

        本文试着提供一个自定义卷积层的简单教程,为了有针对性和实用性,以实现 2.5 维卷积RGB-D 图像语义分割论文:2.5D Convolution for RGB-D Semantic Segmentation)为例。本文是在参考了论文 Pixel-Adaptive Convolutional Neural Networks 的开源项目 pacnet 的基础上实现的,在此对作者表示感谢。

        本文的所有代码都见下文,也可以访问 [GitHub:稍后放出]。

一、2.5 维卷积原理

        对于卷积核为 的 2 维卷积,计它的感受野大小为 ,以 为中心的标准二维卷积计算如下:

标准的二维卷积计算公式,来源:2.5D Convolution for RGB-D Semantic Segmentation(下同)

类似的,标准的 3 维卷积 计算如下():

标准的三维卷积计算公式

        二维卷积(三维卷积)处理图像(视频)数据已经非常成熟,应用十分广泛。对于带有深度信息的 RGB-D 图像的语义分割,如果把深度信息当成一个额外的通道,那么直接使用二维卷积来实现语义分割模型即可。然而,这样做会忽视深度信息中隐藏的几何结构特征,因此有必要设计一种新颖的卷积方式来充分使用深度信息中的几何特征,论文(2.5D Convolution for RGB-D Semantic Segmentation)作者们就设计了一种称为 2.5 维卷积的操作:

2.5 维卷积计算公式

其中 , 为深度信息, 为 个 2 维卷积核的参数, 的计算公式为:

掩模操作

        根据以上公式,如果输入的特征通道数为 ,输出通道数为 ,那么容易知道:

  • 2 维卷积核的参数量:
  • 3 维卷积核的参数量:
  • 2.5 维卷积核的参数量:

如果输入、输出的分辨率都是 (或者 ),那么(大约):

  • 2 维卷积的计算量:
  • 3 维卷积的计算量:
  • 2.5 维卷积的计算量:

显然,虽然相比于 2 维卷积来说,2.5 维卷积的参数量和计算量都要大,但对比 3 维卷积来说,在参数量一致的情况下,2.5 维卷积的计算量却小得多。因此,从——性能上优于 2 维卷积,计算量上优于 3 维卷积——的角度看,2.5 维卷积是有意义的

二、2.5 维卷积实现

        严格按照公式 (4-7)来实现,2.5 维卷积的实现代码为(命名为:conv2_5d.py):

# -*- coding: utf-8 -*-
"""
Created on Wed Nov 20 18:58:19 2019

@author: lijingxiong

Implementation of 2.5D convolution:
    paper: 2.5D Convolution for RGB-D Semantic Segmentation.

Reference: https://github.com/NVlabs/pacnet/blob/master/pac.py
"""

import math
import torch
        
        
class RepeatKernelConvFn(torch.autograd.function.Function):
    """2.5D convolution with kernel.
    """
        
    @staticmethod
    def forward(ctx, inputs, kernel, weight, bias=None, stride=1, padding=0, 
                dilation=1):
        """Forward computation.
        
        Args:
            inputs: A tensor with shape [batch, channels, height, width] 
                representing a batch of images.
            kernel: A tensor with shape [k, batch, channels, N, N, k, k],
                where k = kernel_size and N = number of slide windows.
            weight: A tensor with shape [k, out_channels, in_channels, 
                kernel_size, kernel_size].
            bias: None or a tensor with shape [out_channels].
            
        Returns:
            outputs: A tensor with shape [batch, out_channels, height, width].
        """
        (batch_size, channels), input_size = inputs.shape[:2], inputs.shape[2:]
        ctx.in_channels = channels
        ctx.input_size = input_size
        ctx.kernel_size = tuple(weight.shape[-2:])
        ctx.dilation = torch.nn.modules.utils._pair(dilation)
        ctx.padding = torch.nn.modules.utils._pair(padding)
        ctx.stride = torch.nn.modules.utils._pair(stride)
        
        needs_input_grad = ctx.needs_input_grad
        ctx.save_for_backward(
            inputs if (needs_input_grad[1] or needs_input_grad[2]) else None,
            kernel if (needs_input_grad[0] or needs_input_grad[2]) else None,
            weight if (needs_input_grad[0] or needs_input_grad[1]) else None)
        ctx._backend = torch._thnn.type2backend[inputs.type()]
        
        # Slide windows, [batch, channels x kernel_size x kernel_size, N x N],
        # where N is the number of slide windows.
        inputs_wins = torch.nn.functional.unfold(inputs, ctx.kernel_size, 
                                                 ctx.dilation, ctx.padding,
                                                 ctx.stride)

        inputs_wins = inputs_wins.view(
            1, batch_size, channels, *kernel.shape[3:])
        inputs_mul_kernel = inputs_wins * kernel
                
        # Matrix multiplication
        outputs = torch.einsum(
            'hijklmn,hojmn->iokl', (inputs_mul_kernel, weight))
        
        if bias is not None:
            outputs += bias.view(1, -1, 1, 1)
        return outputs
        
    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_outputs):
        grad_inputs = grad_kernel = grad_weight = grad_bias = None
        batch_size, out_channels = grad_outputs.shape[:2]
        output_size = grad_outputs.shape[2:]
        in_channels = ctx.in_channels
        
        # Compute gradients
        inputs, kernel, weight = ctx.saved_tensors
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
            grad_inputs_mul_kernel = torch.einsum('iokl,hojmn->hijklmn',
                                                  (grad_outputs, weight))
        if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
            inputs_wins = torch.nn.functional.unfold(inputs, ctx.kernel_size, 
                                                     ctx.dilation, ctx.padding,
                                                     ctx.stride)
            inputs_wins = inputs_wins.view(1, batch_size, in_channels,
                                           output_size[0], output_size[1],
                                           ctx.kernel_size[0], 
                                           ctx.kernel_size[1])
        if ctx.needs_input_grad[0]:
            grad_inputs = grad_outputs.new()
            grad_inputs_wins = grad_inputs_mul_kernel * kernel
            grad_inputs_wins = grad_inputs_wins.view(
                ctx.kernel_size[0], batch_size, -1, output_size[0], output_size[1])
            ctx._backend.Im2Col_updateGradInput(ctx._backend.library_state,
                                                grad_inputs_wins,
                                                grad_inputs,
                                                ctx.input_size[0],
                                                ctx.input_size[1],
                                                ctx.kernel_size[0],
                                                ctx.kernel_size[1],
                                                ctx.dilation[0], 
                                                ctx.dilation[1],
                                                ctx.padding[0], 
                                                ctx.padding[1],
                                                ctx.stride[0],
                                                ctx.stride[1])
        if ctx.needs_input_grad[1]:
            grad_kernel = inputs_wins * grad_inputs_mul_kernel
            grad_kernel = grad_kernel.sum(dim=1, keepdim=True)
        if ctx.needs_input_grad[2]:
            inputs_mul_kernel = inputs_wins * kernel
            grad_weight = torch.einsum('iokl,hijklmn->hojmn',
                                       (grad_outputs, inputs_mul_kernel))
        if ctx.needs_input_grad[3]:
            grad_bias = torch.einsum('iokl->o', (grad_outputs,))
        return (grad_inputs, grad_kernel, grad_weight, grad_bias, None, None,
                None)
        
        
class DepthKernelFn(torch.autograd.function.Function):
    """Compute mask in paper: 
        2.5D convolution for rgb-d semantic segmentation.
    """
    
    @staticmethod
    def forward(ctx, depth, f, kernel_size, stride, padding, dilation):
        """Forward computation.
        
        Args:
            depth: A tensor with shape [batch, 1, height, width] representing
                a batch of depth maps.
            f: A constant.
            
        Returns:
            A tensor with shape [k, batch, 1, N, N, k, k], where 
            k = kernel_size and N = number of slide windows.
        """
        ctx.kernel_size = torch.nn.modules.utils._pair(kernel_size)
        ctx.stride = torch.nn.modules.utils._pair(stride)
        ctx.padding = torch.nn.modules.utils._pair(padding)
        ctx.dilation = torch.nn.modules.utils._pair(dilation)
        
        batch_size, channels, in_height, in_width = depth.shape
        out_height = (in_height + 2 * ctx.padding[0] - 
                      ctx.dilation[0] * (ctx.kernel_size[0] - 1)
                      -1) // ctx.stride[0] + 1
        out_width = (in_width + 2 * ctx.padding[1] - 
                     ctx.dilation[1] * (ctx.kernel_size[1] - 1)
                     -1) // ctx.stride[1] + 1
        
        depth_wins = torch.nn.functional.unfold(depth, ctx.kernel_size,
                                                ctx.dilation, ctx.padding,
                                                ctx.stride)
        depth_wins = depth_wins.view(batch_size, channels, out_height, 
                                     out_width, ctx.kernel_size[0],
                                     ctx.kernel_size[1])
        s_wins = depth_wins / f
        
        kernels = []
        center_y, center_x = ctx.kernel_size[0] // 2, ctx.kernel_size[1] // 2
        for l in range(ctx.kernel_size[0]):
            z_l = depth_wins + (l - (ctx.kernel_size[0] - 1) / 2) * s_wins
            z_l_0 = z_l.contiguous()[:, :, :, :, center_y:center_y + 1,
                                     center_x:center_x + 1]
            s_0 = s_wins.contiguous()[:, :, :, :, center_y:center_y + 1,
                                      center_x:center_x + 1]
            mask_l_ge = torch.where(depth_wins >= z_l_0 - s_0 / 2,
                                    torch.full_like(depth_wins, 1),
                                    torch.full_like(depth_wins, 0))
            mask_l_lt = torch.where(depth_wins < z_l_0 + s_0 / 2,
                                    torch.full_like(depth_wins, 1),
                                    torch.full_like(depth_wins, 0))
            mask_l = torch.where(mask_l_ge == mask_l_lt,
                                 mask_l_ge,
                                 torch.full_like(depth_wins, 0))
            kernels.append(mask_l.unsqueeze(dim=0))
        return torch.cat(kernels, dim=0)
    
    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_outputs):
        return 0, None, None, None, None, None
    
    
class Conv2_5d(torch.nn.Module):
    """Implementation of 2.5D convolution."""
    
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, bias=True):
        """Constructor."""
        super(Conv2_5d, self).__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = torch.nn.modules.utils._pair(kernel_size)
        self.stride = torch.nn.modules.utils._pair(stride)
        self.padding = torch.nn.modules.utils._pair(padding)
        self.dilation = torch.nn.modules.utils._pair(dilation)
        
        # Parameters: weight, bias
        self.weight = torch.nn.parameter.Parameter(
            torch.Tensor(kernel_size, out_channels, in_channels, kernel_size,
                         kernel_size))
        if bias:
            self.bias = torch.nn.parameter.Parameter(
                torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
            
        # Initialization
        self.reset_parameters()
        
    def forward(self, inputs, depth, f=1):
        """Forward computation.
        
        Args:
            inputs: A tensor with shape [batch, in_channels, height, width] 
                representing a batch of images.
            depth: A tensor with shape [batch, 1, height, width] representing
                    a batch of depth maps.
            f: A constant.
            
        Returns:
            outputs: A tensor with shape [batch, out_channels, height, width].
        """
        kernel = DepthKernelFn.apply(depth, f, self.kernel_size, self.stride,
                                     self.padding, self.dilation)
        
        outputs = RepeatKernelConvFn.apply(inputs, kernel, self.weight,
                                           self.bias, self.stride,
                                           self.padding, self.dilation)
        return outputs
    
    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        return s.format(**self.__dict__)
    
    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -bound, bound)

        实现自定义卷积层的要点是:

  • 底层计算需要继承 torch.autograd.function.Function

        定义该类的子类时,需要重载它的两个函数:forwardbackward 函数,分别用于前向传播和反向传播的计算。另外,forward 函数返回值的个数等于 backward 参数的个数(不计算 ctx),而 backward 返回值的个数则等于 forward 函数的参数个数(不计算 ctx),即两者的输入输出是一一对应的。顾名思义,backward 函数是利用链式法则forward 函数的所有输入求梯度,如果某个输入不需要求梯度,那么直接给该参数的梯度赋值为 None 即可。

  • 使用 torch.nn.functional.unfold 函数将数据按照滑动窗口分块:

        对于批量 、通道数、分辨率 的输入 ,形状为: ,如果卷积核大小(kernel size)、填充大小(padding)、步幅(stride)、空洞率(dilation)分别为 ,那么该函数的输出大小为:,是一个 3 维张量,其中:

  • 使用 torch.einsum 函数对张量按照卷积运算求和

        根据爱因斯坦和式约定,上下标一致的数据可以省略求和号,如:

把这一约定用符合表示并计算出来就是 einsum 函数:

torch.einsum('i,i->', (a, b))

比如:

a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])
s=torch.einsum('i,i->', (a, b))
# s = tensor(14)

        结合以上两个函数,对于形状为 的深度信息 ,对于固定的 ,根据公式 (7),用 torch.nn.functional.unfold 函数得到形状为 的 (已通过 view 函数改变形状),对所有的 拼接 ,得到形状为 的张量。同理,对于 ,经过滑动窗口操作之后形状为 (额外加第 1 维)。这两个 7 维张量经过元素级的乘法得到新的 7 维张量,形状为 ,然后按照公式 (4)用 torch.einsum 函数对这个 7 维张量和形状为 的权重参数张量求和:torch.einsum('hijklmn,hojmn->iokl', (·, ·)),得到形状为 的输出,其中 为输出通道数。

        以上过程就是继承了 torch.autograd.function.Function 类的两个类: DepthKernelFnRepeatKernelConvFnforward 函数的内容。调用这些类时直接使用 .apply() 函数即可。而 backward 函数就是要对 forward 函数的计算利用链式法则求梯度,因此无需赘言。

        当前向传播和反向传播的计算都定义清楚了之后,还需要将它们封装成一个自定义卷积层,这就是类 Conv2_5d,即最终用来调用的 2.5 维卷积层。它的定义跟平时卷积网络的定义类似,都是直接继承 torch.nn.Module 类和重载 forward 函数,不同的是需要定义权重参数和偏置参数(如果需要的话):

self.weight = torch.nn.parameter.Parameter(torch.Tensor(shape))
self.bias = torch.nn.parameter.Parameter(torch.Tensor(shape))

并适当的初始化它们(见 reset_parameters 函数)。

三、2.5 维卷积实现代码的验证

        为了验证以上实现的代码在反向传播时不会报错,定义一个两层的简单网络来验证如下(命名为:conv2_5d_test.py):

# -*- coding: utf-8 -*-
"""
Created on Wed Nov 27 13:41:23 2019

@author: lijingxiong
"""

import torch

import conv2_5d


class ConvTest(torch.nn.Module):
    """A mini networt to test Conv2_5d in forward and backword computation."""
    
    def __init__(self, num_classes=2):
        super(ConvTest, self).__init__()
        
        self._head_conv = conv2_5d.Conv2_5d(in_channels=3, 
                                            out_channels=32, 
                                            kernel_size=5, 
                                            padding=2, 
                                            bias=False)
        self._pred_conv = torch.nn.Conv2d(in_channels=32,
                                          out_channels=num_classes,
                                          kernel_size=3,
                                          padding=1,
                                          bias=False)
        self._batch_norm = torch.nn.BatchNorm2d(num_features=num_classes,
                                                momentum=0.995)
        
    def forward(self, x, z, f=1):
        x = self._head_conv(x, z, f)
        x = self._pred_conv(x)
        x = self._batch_norm(x)
        return x
    
    
if __name__ == '__main__':
    # Device configuration
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    model = ConvTest().to(device)
    
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    
    num_steps = 100
    for i in range(num_steps):
        images = torch.rand((2, 3, 64, 64)).to(device)
        depth = torch.rand((2, 1, 64, 64)).to(device)
        labels = torch.LongTensor(
            torch.full((2, 64, 64), 0, dtype=torch.int64)).to(device)
        
        # Forward pass
        outputs = model(images, depth)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        print('Step: {}/{}, Loss: {:.4f}'.format(i+1, num_steps, loss.item()))

        直接执行 :

python3 conv2_5d_test.py

代码正常结束,且损失逐渐减小,(暂时)认为代码是正确的。

你可能感兴趣的:(Pytorch 实现自定义卷积:以 2.5 维卷积(2.5D Convolution)为例)