Pruning and Reparameterization, Lesson 8: ACNet, DBB, and RepVGG Reparameterization

Contents

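  • ACNet, DBB, and RepVGG Reparameterization
    • Preface
    • 1. Parallel Multi-Branch Structures
      • 1.1. Parallel Multi-Branch Structure Demo
    • 2. ACNet
      • 2.1 ACNet Overview
      • 2.2 init
      • 2.3 forward
      • 2.4 switch to deploy
      • 2.5 get_equivalent_kernel_bias
      • 2.6 Fusing Conv2d and BN (Reparameterization)
      • 2.7 Fusing Conv1x3, Conv3x1, and Conv3x3 (Reparameterization)
      • 2.8 Model Export
      • 2.9 Complete Example Code
    • Summary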

ACNet, DBB, and RepVGG Reparameterization

Preface

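These are my personal study notes on the new model pruning and reparameterization course released by 手写AI (Shouxie AI), kept for my own reference.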

This lesson focuses on the reparameterization of ACNet.

The course outline is shown in the mind map below:

[Figure 1: course outline mind map]

1. Parallel Multi-Branch Structures

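  • Parallel multi-branch structures have become a widely used technique in deep learning in recent years. Their purpose is to make the extracted features more expressive and to improve the model's generalization. Some related milestone papers and their years: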

DenseNet

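  • Paper: 2016, Densely Connected Convolutional Networks
  • This paper proposed a new network structure in which each layer is connected to all preceding layers, forming a densely connected network. This structure makes better use of feature reuse during training and yields more expressive features. DenseNet is an important representative of multi-branch structures.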

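VoVNet

  • Paper: 2019, An Energy and GPU-Computation Efficient Backbone Network for Real-Time Object Detection
  • This paper proposed One-Shot Aggregation (OSA): inside each module, the outputs of all intermediate conv layers are concatenated once at the end, rather than at every layer as in DenseNet. This preserves diverse features while being cheaper in memory access and GPU computation.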

Res2Net

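  • Paper: 2019, Res2Net: A New Multi-scale Backbone Architecture
  • This paper proposed a multi-scale feature extraction method built into the residual block: the features are split into groups that are connected by hierarchical residual-like 3x3 convolutions, so a single block covers several receptive-field scales. It can also be combined with modules such as SE attention, and it improves model performance.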

PeleeNet

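  • Paper: 2018, Pelee: A Real-Time Object Detection System on Mobile Devices
  • This paper proposed PeleeNet, a lightweight convolutional network whose two-way dense layers use parallel branches with different receptive fields to extract features. Together with a stem block and other efficiency-oriented design choices, it reduces parameters and computation, enabling real-time object detection on mobile devices.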

1.1. Parallel Multi-Branch Structure Demo

Let's start by implementing a simple parallel multi-branch network, as shown below:

[Figure 2: parallel multi-branch network structure]

Example code:

import torch
import torch.nn as nn

class SimpleParallelNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=4, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=4, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
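        # 3 branches x 4 channels = 12 input channels; each branch pools 32x32 -> 16x16
        self.conv = nn.Conv2d(12, 32, kernel_size=3, stride=1, padding=1)
        # 32 channels * 16 * 16 = 8192 features for a 32x32 input
        self.fc   = nn.Linear(8192, 10)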
    
    def forward(self, x):
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x  = torch.cat((x1, x2, x3), dim=1)
        x  = self.conv(x)
        x  = x.view(int(x.size(0)), -1)
        x  = self.fc(x)
        return x

if __name__ == "__main__":
    model = SimpleParallelNet()
    dummy_input = torch.randn((1, 3, 32, 32))
    model(dummy_input)
    output_path = './test.onnx'
    torch.onnx.export(model, dummy_input, output_path)

2. ACNet

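  • Multi-branch structures are more expressive, but they also increase computation. Is there a way to gain expressiveness while keeping computation low?
  • Answer: use multiple branches during training and merge them into a single branch for inference. ACNet was the starting point of this approach.
  • Reparameterization means reorganizing a network's parameters into an equivalent form, usually one with fewer parameters and higher computational efficiency, while preserving the network's behavior. Concretely, the number of convolution operations can be reduced with tricks such as merging several convolutions into a single equivalent one.
  • Reparameterization in ACNet can be viewed as a way to speed up convolution: the structure that performs three convolutions at training time is converted into one that performs a single convolution, which reduces computation and parameters. Specifically, the 3x3, 1x3, and 3x1 branches (each with its own BN) are merged into one equivalent 3x3 convolution.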
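Why can several branches be merged into one? A minimal sketch using `torch.nn.functional.conv2d` with arbitrary sizes: convolution is linear in its kernel, so as long as two branches see the same input with the same stride and aligned output grids, and nothing nonlinear sits between each conv and the sum, their outputs add up to a single conv whose kernel and bias are the sums.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 2, 8, 8)                                # (N, C_in, H, W)
w1, w2 = torch.randn(4, 2, 3, 3), torch.randn(4, 2, 3, 3)  # kernels of two parallel branches
b1, b2 = torch.randn(4), torch.randn(4)                    # their biases

# Training-time view: two parallel conv branches, outputs summed
two_branch = F.conv2d(x, w1, b1, padding=1) + F.conv2d(x, w2, b2, padding=1)
# Inference-time view: a single conv with the summed kernel and summed bias
one_branch = F.conv2d(x, w1 + w2, b1 + b2, padding=1)

print(torch.allclose(two_branch, one_branch, atol=1e-4))  # True
```

A BN layer in eval mode is a per-channel affine map, so it can be folded into its conv first; that is what makes the same trick work for conv+BN branches.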

2.1 ACNet Overview

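ACNet proposes replacing the standard square convolution kernel with a set of kernels during training, with the aim of making the model more robust to image flipping and rotation.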

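During training, each original 3x3 convolution is replaced by parallel 3x3, 1x3, and 3x1 convolutions (each followed by its own BN), and their outputs are added to form the layer's output.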

[Figure 3: ACNet asymmetric convolution block with 3x3, 1x3, and 3x1 branches]

This process introduces no extra hyperparameters and adds no computation at inference time.

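At inference time, the fused kernel parameters are loaded into an ordinary network with the original structure (a single 3x3 convolution per layer), so the improved feature extraction comes at no extra inference cost.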
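The asymmetric branches fit into a single 3x3 conv because a 1x3 (or 3x1) kernel is just a 3x3 kernel whose other rows (or columns) are zero. The sketch below checks this with `torch.nn.functional.pad`, assuming a 3x3 conv with `padding=1` and the matching asymmetric paddings `(0, 1)` and `(1, 0)`; the tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 2, 8, 8)
k_hor = torch.randn(4, 2, 1, 3)   # horizontal 1x3 kernel
k_ver = torch.randn(4, 2, 3, 1)   # vertical 3x1 kernel

# Asymmetric convs, padded so their outputs align with a 3x3 conv using padding=1
y_hor = F.conv2d(x, k_hor, padding=(0, 1))
y_ver = F.conv2d(x, k_ver, padding=(1, 0))

# Zero-pad each kernel to 3x3; F.pad order for the last two dims is (left, right, top, bottom)
k_hor_3x3 = F.pad(k_hor, (0, 0, 1, 1))   # 1x3 becomes the middle row of a 3x3
k_ver_3x3 = F.pad(k_ver, (1, 1, 0, 0))   # 3x1 becomes the middle column of a 3x3

print(torch.allclose(y_hor, F.conv2d(x, k_hor_3x3, padding=1), atol=1e-5))  # True
print(torch.allclose(y_ver, F.conv2d(x, k_ver_3x3, padding=1), atol=1e-5))  # True
```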

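ACNet performs strongly in experiments, and the paper reports noticeably better robustness than the normal-convolution baseline on flipped and rotated images. Its distinguishing feature is the use of asymmetric convolution kernels, hence the name Asymmetric Convolution. Because it improves robustness to flipping and rotation without adding inference cost, it can raise accuracy across a variety of networks.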

2.2 init

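ACNet's initialization function. In deploy mode it builds a single fused conv; in training mode it builds a square (k x k), a vertical (k x 1), and a horizontal (1 x k) conv branch, each followed by BN: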

import torch.nn as nn

class ACNet(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding = 0, dilation = 1,
                 groups = 1, padding_mode = 'zeros',
                 deploy = False, use_affine = True):
        super().__init__()
        self.deploy = deploy
        if self.deploy:
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, kernel_size),
                                        stride=stride, padding=padding, dilation=dilation,
                                        groups=groups, bias=True, padding_mode=padding_mode)
        else:
            self.square_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, kernel_size),
                                         stride=stride, padding=padding, dilation=dilation,
                                         groups=groups, bias=True, padding_mode=padding_mode)
            
            self.square_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)

            if padding - kernel_size // 2 >= 0:
                self.crop = 0
                hor_padding = [padding - kernel_size // 2, padding]
                ver_padding = [padding, padding - kernel_size // 2]
            else:
                self.crop = kernel_size // 2 - padding
                hor_padding = [0, padding]
                ver_padding = [padding, 0]
            
            self.ver_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, 1),
                                      stride=stride, padding=ver_padding, dilation=dilation,
                                      groups=groups, bias=True, padding_mode=padding_mode)
            
            self.ver_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)

            self.hor_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, kernel_size),
                                      stride=stride, padding=hor_padding, dilation=dilation,
                                      groups=groups, bias=True, padding_mode=padding_mode)
            
            self.hor_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)

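use_affine is passed as BatchNorm2d's affine argument: when True, $\gamma$ and $\beta$ are learnable parameters; when False, BN has no $\gamma$ or $\beta$ at all (equivalent to $\gamma=1$, $\beta=0$).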
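A quick check of what `affine=False` means in practice (sizes and statistics below are made up): `BatchNorm2d` then registers no weight or bias, so in eval mode it only normalizes with the running statistics. This matters for fusion later: code that reads `bn.weight` directly has to fall back to $\gamma=1$, $\beta=0$ when `use_affine=False`.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4, affine=False)
print(bn.weight, bn.bias)  # None None: no gamma/beta are created

# Made-up running statistics, just for illustration
bn.running_mean.uniform_(-0.1, 0.1)
bn.running_var.uniform_(0.5, 1.0)
bn.eval()  # use running statistics, as at inference time

x = torch.randn(2, 4, 5, 5)
std = (bn.running_var + bn.eps).sqrt()
# Manual normalization with gamma=1, beta=0
manual = (x - bn.running_mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)
y = bn(x)
print(torch.allclose(y, manual, atol=1e-5))  # True
```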

2.3 forward

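ACNet's forward function. In training mode the three branch outputs are summed; when padding is smaller than kernel_size // 2, the input to the asymmetric branches is cropped so that all three outputs have the same size and stay aligned: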

import torch.nn as nn

class ACNet(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding = 0, dilation = 1,
                 groups = 1, padding_mode = 'zeros',
                 deploy = False, use_affine = True):
        super().__init__()
        pass
    
    def forward(self, input):
        if self.deploy:
            return self.fused_conv(input)
        else:
            square_outputs = self.square_conv(input)
            square_outputs = self.square_bn(square_outputs)

            if self.crop > 0:
                ver_input = input[:, :, :, self.crop:-self.crop]
                hor_input = input[:, :, self.crop:-self.crop, :]
            else:
                ver_input = input
                hor_input = input
            vertical_outputs = self.ver_conv(ver_input)
            vertical_outputs = self.ver_bn(vertical_outputs)
            horizontal_outputs = self.hor_conv(hor_input)
            horizontal_outputs = self.hor_bn(horizontal_outputs)
            result = square_outputs + vertical_outputs + horizontal_outputs
            return result
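When `padding` is smaller than `kernel_size // 2` (the demo in 2.8 uses `padding=0`), the asymmetric branches cannot use negative padding, so `forward` crops the input instead. The sketch below is a standalone re-creation with `torch.nn.functional.conv2d`, not the class itself; it checks that the three branches then produce the same output size and that the cropped 3x1 branch lines up with the middle column of the 3x3 window.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 2, 62, 62)
kernel_size, padding = 3, 0
crop = kernel_size // 2 - padding   # = 1, same rule as ACNet.__init__

k_sq = torch.randn(8, 2, 3, 3)
k_ver = torch.randn(8, 2, 3, 1)
k_hor = torch.randn(8, 2, 1, 3)

square = F.conv2d(x, k_sq, padding=padding)
# Vertical 3x1 branch: crop `crop` columns from each side, no padding
ver = F.conv2d(x[:, :, :, crop:-crop], k_ver)
# Horizontal 1x3 branch: crop `crop` rows from each side, no padding
hor = F.conv2d(x[:, :, crop:-crop, :], k_hor)
print(square.shape, ver.shape, hor.shape)  # each is torch.Size([1, 8, 60, 60])

# Alignment: the cropped 3x1 conv equals a 3x3 conv whose kernel is zero except the middle column
ver_ref = F.conv2d(x, F.pad(k_ver, (1, 1, 0, 0)), padding=padding)
print(torch.allclose(ver, ver_ref, atol=1e-5))  # True
```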

2.4 switch to deploy

ACNet's conversion function, which turns the three training-time kernels into the single kernel used at inference time:

import torch.nn as nn

class ACNet(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding = 0, dilation = 1,
                 groups = 1, padding_mode = 'zeros',
                 deploy = False, use_affine = True):
        super().__init__()
        pass
        
    def forward(self, input):
        pass

    def swtich_to_deploy(self):
        self.deploy = True
        deploy_k, deploy_b = self.get_equivalent_kernel_bias()
        self.fused_conv = nn.Conv2d(in_channels=self.square_conv.in_channels,
                                    out_channels=self.square_conv.out_channels,
                                    kernel_size=self.square_conv.kernel_size,
                                    stride=self.square_conv.stride,
                                    padding=self.square_conv.padding,
                                    dilation=self.square_conv.dilation,
                                    groups=self.square_conv.groups,
                                    bias=True,
                                    padding_mode=self.square_conv.padding_mode)

        self.__delattr__('square_conv')
        self.__delattr__('square_bn')
        self.__delattr__('hor_conv')
        self.__delattr__('hor_bn')
        self.__delattr__('ver_conv')
        self.__delattr__('ver_bn')

        self.fused_conv.weight.data = deploy_k
        self.fused_conv.bias.data = deploy_b

    def get_equivalent_kernel_bias(self):
        pass
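`swtich_to_deploy` writes the fused tensors through `.data`. An alternative that is commonly preferred in PyTorch code is an in-place `copy_` under `torch.no_grad()`: it keeps the existing `Parameter` objects, records no autograd history, and raises an error if the fused tensor cannot be broadcast to the parameter's shape (assigning `.data` silently accepts any shape). A minimal sketch, with random stand-ins for `deploy_k` and `deploy_b`:

```python
import torch
import torch.nn as nn

fused = nn.Conv2d(2, 8, kernel_size=3, bias=True)
deploy_k = torch.randn(8, 2, 3, 3)   # stand-in for the fused kernel
deploy_b = torch.randn(8)            # stand-in for the fused bias

with torch.no_grad():
    fused.weight.copy_(deploy_k)     # in-place copy into the existing Parameter
    fused.bias.copy_(deploy_b)

print(torch.equal(fused.weight, deploy_k), fused.weight.requires_grad)  # True True
```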

2.5 get_equivalent_kernel_bias

The function that actually computes ACNet's reparameterized kernel and bias:

import torch.nn as nn

class ACNet(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding = 0, dilation = 1,
                 groups = 1, padding_mode = 'zeros',
                 deploy = False, use_affine = True):
        super().__init__()
        pass
    
    def forward(self, input):
        pass

    def swtich_to_deploy(self):
        pass

    def get_equivalent_kernel_bias(self):
        hor_k, hor_b = self.fuse_bn_tensor(self.hor_conv, self.hor_bn)
        ver_k, ver_b = self.fuse_bn_tensor(self.ver_conv, self.ver_bn)
        square_k, square_b = self.fuse_bn_tensor(self.square_conv, self.square_bn)

        self.add_to_square_kernel(square_k, hor_k)
        self.add_to_square_kernel(square_k, ver_k)

        return square_k, hor_b + ver_b + square_b

    def fuse_bn_tensor(self, conv, bn):
        pass

    def add_to_square_kernel(self, square_kernel, asym_kernel):
        pass

2.6 Fusing Conv2d and BN (Reparameterization)

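ACNet's Conv2d and BN fusion function works as illustrated below. With BN in inference mode, each output channel's conv kernel is multiplied by $\frac{\gamma}{\sqrt{\sigma_B^2+\epsilon}}$, and the fused bias becomes $\beta + \frac{\gamma\,(b-\mu_B)}{\sqrt{\sigma_B^2+\epsilon}}$, where $b$ is the conv's own bias ($b=0$ if it has none):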

[Figure 4: fusing Conv2d and BN]

import torch.nn as nn

class ACNet(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding = 0, dilation = 1,
                 groups = 1, padding_mode = 'zeros',
                 deploy = False, use_affine = True):
        super().__init__()
        self.deploy = deploy
        pass
    
    def forward(self, input):
        pass

    def get_equivalent_kernel_bias(self):
        hor_k, hor_b = self.fuse_bn_tensor(self.hor_conv, self.hor_bn)
        ver_k, ver_b = self.fuse_bn_tensor(self.ver_conv, self.ver_bn)
        square_k, square_b = self.fuse_bn_tensor(self.square_conv, self.square_bn)

        self.add_to_square_kernel(square_k, hor_k)
        self.add_to_square_kernel(square_k, ver_k)

        return square_k, hor_b + ver_b + square_b

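    def fuse_bn_tensor(self, conv, bn):
        # Fold BN (inference mode) into the preceding conv:
        #   W' = W * gamma / std,  b' = beta + (b - mean) * gamma / std
        std = (bn.running_var + bn.eps).sqrt()
        # With use_affine=False, bn.weight / bn.bias are None: treat them as gamma=1, beta=0
        gamma = bn.weight if bn.weight is not None else 1.0
        beta = bn.bias if bn.bias is not None else 0.0
        # The branch convs are built with bias=True, so their bias must be folded in too
        conv_bias = conv.bias if conv.bias is not None else 0.0
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return conv.weight * t, beta + (conv_bias - bn.running_mean) * gamma / std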

    def add_to_square_kernel(self, square_kernel, asym_kernel):
        pass
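A standalone numerical check of the fusion formula above (layer sizes and BN statistics are arbitrary, chosen only so the check is non-trivial). It folds the conv's own bias as well, which matters here because the ACNet branches are built with `bias=True`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(2, 8, kernel_size=3, padding=1, bias=True)
bn = nn.BatchNorm2d(8)
with torch.no_grad():  # made-up BN statistics and affine params
    bn.running_mean.uniform_(-0.5, 0.5)
    bn.running_var.uniform_(0.5, 1.5)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
bn.eval()  # fusion is only valid with BN in inference mode (running statistics)

with torch.no_grad():
    std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / std                      # gamma / sqrt(var + eps)
    fused = nn.Conv2d(2, 8, kernel_size=3, padding=1, bias=True)
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    fused.bias.copy_(bn.bias + (conv.bias - bn.running_mean) * scale)

    x = torch.randn(1, 2, 16, 16)
    y_ref, y_fused = bn(conv(x)), fused(x)
print(torch.allclose(y_ref, y_fused, atol=1e-5))  # True
```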

2.7 Fusing Conv1x3, Conv3x1, and Conv3x3 (Reparameterization)

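ACNet's kernel-fusion function, illustrated below: the 1x3 and 3x1 kernels are added into the middle row and middle column of the 3x3 kernel.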

[Figure 5: fusing the 1x3, 3x1, and 3x3 kernels]

import torch.nn as nn

class ACNet(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding = 0, dilation = 1,
                 groups = 1, padding_mode = 'zeros',
                 deploy = False, use_affine = True):
        super().__init__()
        self.deploy = deploy
        pass
    
    def forward(self, input):
        pass

    def swtich_to_deploy(self):
        pass

    def get_equivalent_kernel_bias(self):
        hor_k, hor_b = self.fuse_bn_tensor(self.hor_conv, self.hor_bn)
        ver_k, ver_b = self.fuse_bn_tensor(self.ver_conv, self.ver_bn)
        square_k, square_b = self.fuse_bn_tensor(self.square_conv, self.square_bn)

        self.add_to_square_kernel(square_k, hor_k)
        self.add_to_square_kernel(square_k, ver_k)

        return square_k, hor_b + ver_b + square_b

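    def fuse_bn_tensor(self, conv, bn):
        # Fold BN (inference mode) into the preceding conv:
        #   W' = W * gamma / std,  b' = beta + (b - mean) * gamma / std
        std = (bn.running_var + bn.eps).sqrt()
        # With use_affine=False, bn.weight / bn.bias are None: treat them as gamma=1, beta=0
        gamma = bn.weight if bn.weight is not None else 1.0
        beta = bn.bias if bn.bias is not None else 0.0
        # The branch convs are built with bias=True, so their bias must be folded in too
        conv_bias = conv.bias if conv.bias is not None else 0.0
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return conv.weight * t, beta + (conv_bias - bn.running_mean) * gamma / std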

    def add_to_square_kernel(self, square_kernel, asym_kernel):
        asym_h = asym_kernel.size(2)
        asym_w = asym_kernel.size(3)

        square_h = square_kernel.size(2)
        square_w = square_kernel.size(3)

        square_kernel[:,
                      :,
                      square_h // 2 - asym_h // 2: square_h // 2 - asym_h // 2 + asym_h,
                      square_w // 2 - asym_w // 2: square_w // 2 - asym_w // 2 + asym_w] += asym_kernel
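To see which slice `add_to_square_kernel` writes into, here is a standalone version of the same index arithmetic (reading the width from `size(3)`) applied to a 5x5 kernel, which shows it is not specific to 3x3:

```python
import torch

def add_to_square_kernel(square_kernel, asym_kernel):
    # Add asym_kernel (kh x kw) into the center of square_kernel (k x k), in place
    asym_h, asym_w = asym_kernel.size(2), asym_kernel.size(3)
    square_h, square_w = square_kernel.size(2), square_kernel.size(3)
    top = square_h // 2 - asym_h // 2
    left = square_w // 2 - asym_w // 2
    square_kernel[:, :, top:top + asym_h, left:left + asym_w] += asym_kernel

k = torch.zeros(1, 1, 5, 5)
add_to_square_kernel(k, torch.ones(1, 1, 1, 5))   # 1x5 lands on row 2
add_to_square_kernel(k, torch.ones(1, 1, 5, 1))   # 5x1 lands on column 2
print(k[0, 0])  # a cross: row 2 and column 2 are 1, the center is 2
```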

2.8 Model Export

Exporting the ACNet network:

import torch
import torch.nn as nn

class ACNet(nn.Module):
    pass

if __name__ == '__main__':
    
    dummy_input = torch.randn(1, 2, 62, 62)
    
    model = ACNet(in_channels=2, out_channels=8, kernel_size=3, padding=0, stride=1, deploy=False)
    
    model.eval()
    
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            nn.init.uniform_(module.running_mean, 0, 0.1)
            nn.init.uniform_(module.running_var, 0, 0.2)
            nn.init.uniform_(module.weight, 0, 0.3)
            nn.init.uniform_(module.bias, 0, 0.4)
    
    output = model(dummy_input)
    print(model)
    torch.onnx.export(model=model, args=dummy_input, f='./ACNet.onnx', verbose=False)

    model.swtich_to_deploy()
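    deployout = model(dummy_input)
    # The single fused conv should reproduce the three-branch output up to float rounding
    print('max abs diff:', (output - deployout).abs().max().item())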
    print(model)
    torch.onnx.export(model=model, args=dummy_input, f='./ACNet-deploy.onnx', verbose=False)
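To put a number on the parameter reduction for the demo configuration (`in_channels=2`, `out_channels=8`, `kernel_size=3`), the sketch below rebuilds the training-time branches as plain `nn` modules and counts learnable parameters. It is an illustration, not part of the original code; BN's `running_mean`/`running_var` are buffers and are not counted.

```python
import torch.nn as nn

def count_params(*modules):
    # Sum the element counts of all learnable parameters
    return sum(p.numel() for m in modules for p in m.parameters())

# Training-time branches of the demo block
train_branches = [
    nn.Conv2d(2, 8, (3, 3), bias=True), nn.BatchNorm2d(8),
    nn.Conv2d(2, 8, (3, 1), bias=True), nn.BatchNorm2d(8),
    nn.Conv2d(2, 8, (1, 3), bias=True), nn.BatchNorm2d(8),
]
# Deploy-time replacement
fused = nn.Conv2d(2, 8, (3, 3), bias=True)

print(count_params(*train_branches))  # 312
print(count_params(fused))            # 152
```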

2.9 Complete Example Code

The complete example code:

import torch
import torch.nn as nn

class ACNet(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding = 0, dilation = 1,
                 groups = 1, padding_mode = 'zeros',
                 deploy = False, use_affine = True):
        super().__init__()
        self.deploy = deploy
        if self.deploy:
            self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, kernel_size),
                                        stride=stride, padding=padding, dilation=dilation,
                                        groups=groups, bias=True, padding_mode=padding_mode)
        else:
            self.square_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, kernel_size),
                                         stride=stride, padding=padding, dilation=dilation,
                                         groups=groups, bias=True, padding_mode=padding_mode)
            
            self.square_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)

            if padding - kernel_size // 2 >= 0:
                self.crop = 0
                hor_padding = [padding - kernel_size // 2, padding]
                ver_padding = [padding, padding - kernel_size // 2]
            else:
                self.crop = kernel_size // 2 - padding
                hor_padding = [0, padding]
                ver_padding = [padding, 0]
            
            self.ver_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, 1),
                                      stride=stride, padding=ver_padding, dilation=dilation,
                                      groups=groups, bias=True, padding_mode=padding_mode)
            
            self.ver_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)

            self.hor_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, kernel_size),
                                      stride=stride, padding=hor_padding, dilation=dilation,
                                      groups=groups, bias=True, padding_mode=padding_mode)
            
            self.hor_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
    
    def forward(self, input):
        if self.deploy:
            return self.fused_conv(input)
        else:
            square_outputs = self.square_conv(input)
            square_outputs = self.square_bn(square_outputs)

            if self.crop > 0:
                ver_input = input[:, :, :, self.crop:-self.crop]
                hor_input = input[:, :, self.crop:-self.crop, :]
            else:
                ver_input = input
                hor_input = input
            vertical_outputs = self.ver_conv(ver_input)
            vertical_outputs = self.ver_bn(vertical_outputs)
            horizontal_outputs = self.hor_conv(hor_input)
            horizontal_outputs = self.hor_bn(horizontal_outputs)
            result = square_outputs + vertical_outputs + horizontal_outputs
            return result

    def swtich_to_deploy(self):
        self.deploy = True
        deploy_k, deploy_b = self.get_equivalent_kernel_bias()
        self.fused_conv = nn.Conv2d(in_channels=self.square_conv.in_channels,
                                    out_channels=self.square_conv.out_channels,
                                    kernel_size=self.square_conv.kernel_size,
                                    stride=self.square_conv.stride,
                                    padding=self.square_conv.padding,
                                    dilation=self.square_conv.dilation,
                                    groups=self.square_conv.groups,
                                    bias=True,
                                    padding_mode=self.square_conv.padding_mode)

        self.__delattr__('square_conv')
        self.__delattr__('square_bn')
        self.__delattr__('hor_conv')
        self.__delattr__('hor_bn')
        self.__delattr__('ver_conv')
        self.__delattr__('ver_bn')

        self.fused_conv.weight.data = deploy_k
        self.fused_conv.bias.data = deploy_b

    def get_equivalent_kernel_bias(self):
        hor_k, hor_b = self.fuse_bn_tensor(self.hor_conv, self.hor_bn)
        ver_k, ver_b = self.fuse_bn_tensor(self.ver_conv, self.ver_bn)
        square_k, square_b = self.fuse_bn_tensor(self.square_conv, self.square_bn)

        self.add_to_square_kernel(square_k, hor_k)
        self.add_to_square_kernel(square_k, ver_k)

        return square_k, hor_b + ver_b + square_b

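    def fuse_bn_tensor(self, conv, bn):
        # Fold BN (inference mode) into the preceding conv:
        #   W' = W * gamma / std,  b' = beta + (b - mean) * gamma / std
        std = (bn.running_var + bn.eps).sqrt()
        # With use_affine=False, bn.weight / bn.bias are None: treat them as gamma=1, beta=0
        gamma = bn.weight if bn.weight is not None else 1.0
        beta = bn.bias if bn.bias is not None else 0.0
        # The branch convs are built with bias=True, so their bias must be folded in too
        conv_bias = conv.bias if conv.bias is not None else 0.0
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return conv.weight * t, beta + (conv_bias - bn.running_mean) * gamma / std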

    def add_to_square_kernel(self, square_kernel, asym_kernel):
        asym_h = asym_kernel.size(2)
        asym_w = asym_kernel.size(3)

        square_h = square_kernel.size(2)
        square_w = square_kernel.size(3)

        square_kernel[:,
                      :,
                      square_h // 2 - asym_h // 2: square_h // 2 - asym_h // 2 + asym_h,
                      square_w // 2 - asym_w // 2: square_w // 2 - asym_w // 2 + asym_w] += asym_kernel

if __name__ == '__main__':
    
    dummy_input = torch.randn(1, 2, 62, 62)
    
    model = ACNet(in_channels=2, out_channels=8, kernel_size=3, padding=0, stride=1, deploy=False)
    
    model.eval()
    
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            nn.init.uniform_(module.running_mean, 0, 0.1)
            nn.init.uniform_(module.running_var, 0, 0.2)
            nn.init.uniform_(module.weight, 0, 0.3)
            nn.init.uniform_(module.bias, 0, 0.4)
    
    output = model(dummy_input)
    print(model)
    torch.onnx.export(model=model, args=dummy_input, f='./ACNet.onnx', verbose=False)

    model.swtich_to_deploy()
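    deployout = model(dummy_input)
    # The single fused conv should reproduce the three-branch output up to float rounding
    print('max abs diff:', (output - deployout).abs().max().item())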
    print(model)
    torch.onnx.export(model=model, args=dummy_input, f='./ACNet-deploy.onnx', verbose=False)

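In ACNet's constructor we specify the number of input channels, output channels, kernel size, stride, padding, dilation, groups, padding_mode, and so on. deploy=True selects deploy mode, in which the fused convolution is used instead of the separate branch convolutions. (from ChatGPT)

If deploy is False, the submodules square_conv, square_bn, ver_conv, ver_bn, hor_conv, and hor_bn are built. In forward, the input is fed in parallel to the square, vertical, and horizontal convolutions (each followed by its BN), and the three results are summed.

During reparameterization, every convolution is first fused with its BN layer. In ACNet, the fused kernels and biases of the horizontal and vertical branches are then added to the fused kernel and bias of the square branch.

get_equivalent_kernel_bias calls fuse_bn_tensor to fuse each conv with its BN: the conv weight is multiplied per output channel by the BN scale factor $\gamma/\sqrt{\sigma^2+\epsilon}$, and the bias is recomputed from the BN parameters and running statistics.

add_to_square_kernel adds the horizontal and vertical kernels into the center of the square kernel (its middle row and middle column, respectively).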
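The three steps described above (fold BN into each branch, place the asymmetric kernels in the center of the square kernel, sum kernels and biases) can be checked end to end with plain tensors. This is a minimal functional sketch under assumed sizes, with hypothetical helper names `make_branch`, `run_branch`, and `fold`; it is not the ACNet class itself, and it assumes a 3x3 square kernel with `padding=1`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C_in, C_out, EPS = 2, 4, 1e-5
x = torch.randn(1, C_in, 10, 10)

def make_branch(kh, kw):
    # Hypothetical helper: conv kernel + bias + made-up BN params/statistics
    return dict(w=torch.randn(C_out, C_in, kh, kw), b=torch.randn(C_out),
                gamma=torch.rand(C_out) + 0.5, beta=torch.randn(C_out),
                mean=torch.randn(C_out), var=torch.rand(C_out) + 0.5)

def run_branch(x, br, padding):
    # Training-time branch: conv followed by BN in inference mode
    y = F.conv2d(x, br["w"], br["b"], padding=padding)
    return F.batch_norm(y, br["mean"], br["var"], br["gamma"], br["beta"],
                        training=False, eps=EPS)

def fold(br):
    # Fold BN into the conv: W' = W * gamma / std, b' = beta + (b - mean) * gamma / std
    scale = br["gamma"] / (br["var"] + EPS).sqrt()
    return br["w"] * scale.reshape(-1, 1, 1, 1), br["beta"] + (br["b"] - br["mean"]) * scale

square, hor, ver = make_branch(3, 3), make_branch(1, 3), make_branch(3, 1)
# Training time: three branches summed
y_train = (run_branch(x, square, 1) + run_branch(x, hor, (0, 1))
           + run_branch(x, ver, (1, 0)))

# Deploy time: one 3x3 conv
k_sq, b_sq = fold(square)
k_h, b_h = fold(hor)
k_v, b_v = fold(ver)
k = k_sq.clone()
k[:, :, 1:2, :] += k_h   # 1x3 into the middle row
k[:, :, :, 1:2] += k_v   # 3x1 into the middle column
y_deploy = F.conv2d(x, k, b_sq + b_h + b_v, padding=1)

print(torch.allclose(y_train, y_deploy, atol=1e-4))  # True
```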

Comparison of the exported ONNX graphs:

[Figure 6: exported ONNX graphs, training-time model vs deploy-time model]

Summary

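This lesson started from parallel multi-branch structures, covering what they are good for and their drawback of larger parameter counts and computation. It then introduced the idea of reparameterization and walked through ACNet's reparameterization, which turns the three convolutions used at training time into a single convolution at inference time. Next time: DBB reparameterization.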
