MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn

目录

        • 稀疏卷积和 BN 的融合
          • 当前模块属于 SparseSequential 并且第一个子模块属于 SparseConvolution时,走165行的分支。
          • 当前模块属于 SparseBasicBlock
          • 当前模块属于 ReLU
      • 2D 卷积和 BN 的融合
        • 当前模块的子类属于 SyncBatchNorm 或不同维度的 BatchNorm 或 LazyBatchNorm
        • 当前模块的子类属于 Conv2d 或者 QuantConv2d
        • 不是以上两种情况
        • 1.7.4 融合后网络特点

稀疏卷积和 BN 的融合
  • 原始版本参考

https://github.com/traveller59/spconv/blob/master/example/fuse_bn_act.py

  • 核心思想还是标准卷积的融合思想

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第1张图片

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第2张图片

当前模块属于 SparseSequential 并且第一个子模块属于 SparseConvolution时,走165行的分支。

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第3张图片
自定义的SparseConvolutionQuant层是继承SparseConvolution

  • 166行,取出conv,bn,relu。

  • 167行,如果当前模块属于 SparseSequential 并且其中的第一个模块属于 SparseConvolution 的模块,会进行 SparseConvolutionBN 融合操作。

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第4张图片

  • 原来卷积权重形状由 [k1, k2, k3, in, out] 表示,代表卷积核大小的k1、k2、可k3,输入通道大小in,输出通道大小out。上图79行permute重新排序后变为 [out, k1, k2, k3, in]的形状,作为fuse_bn_weights方法的一个入参。

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第5张图片

以第一个 SparseConvolutionQuant 为例,这里的入参的信息如下
MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第6张图片

56行,计算得到的 Ndim 和 permute 如下:

在这里插入图片描述

57行,将 conv_w_OKI 的维度重新排序得到 conv_w_OIK,[out, in, k1, k2, k3],形状为 [16, 5, 3, 3, 3]。分别代表输出通道数量16,输入通道数量5,卷积核大小3,3,3。

Pytorch 融合权重和偏置 (line 32)

如果这里稀疏卷积的偏置或者批量归一的权重或偏置为 None,就会为它们创一个 bn_rm 形状的张量。

  • reshape中一长串取值为[-1, 1, 1, 1, 1]

  • (bn_w * bn_var_rsqrt).shapetorch.Size([16]),reshape后为torch.Size([16, 1, 1, 1, 1])

  • conv_w_OIK形状 [16, 5, 3, 3, 3]

在这里插入图片描述

c o n v . w ∗ = c o n v . w ∗ b n . w b n . v a r + b n . e p s conv.w^* = \frac{conv.w * bn.w}{\sqrt{bn.var + bn.eps}} conv.w=bn.var+bn.eps conv.wbn.w

在这里插入图片描述

c o n v . b ∗ = c o n v . b − b n . m e a n b n . v a r + b n . e p s ∗ b n . w + b n . b conv.b^* = \frac{conv.b - bn.mean}{\sqrt{bn.var + bn.eps}} * bn.w + bn.b conv.b=bn.var+bn.eps conv.bbn.meanbn.w+bn.b

通过这样的计算之后,就将这个稀疏卷积与 bn 进行了融合

在这里插入图片描述

最后将卷积权重还原为之前的维度顺序,[out, k1, k2, k3, in],返回融合后的权重与偏置。

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第7张图片

之后将权重的形状还原为稀疏卷积的原始权重维度,[k1, k2, k3, in, out]。

在这里插入图片描述

将融合后的稀疏卷积层与 ReLU 合并为一个 SparseSequential,最后再将之前 module 替换为这个新的稀疏序列就实现了 bn 融合的操作。

  • 替换前的SparseSequential中有c b r,其中的c如下。
SparseConvolutionQunat(
  (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
  (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
)
  • 替换后的SparseSequential中只剩c与r。
SparseSequential(
  (0): SparseConvolutionQunat(
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
  (1): ReLU(inplace=True)
)

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第8张图片

  • 这里的替换操作只要就是有一个递归函数来实现的,通过名字来寻找到没有子模块的子模块,将该子模块替换为融合后的模块。
  • 最终使用 setattr 方法,实现替换模块
当前模块属于 SparseBasicBlock

在这里插入图片描述

接下来以下面这个模块来介绍后续的处理

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第9张图片
MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第10张图片

先将当前 module 的 forward 函数进行替换,这里的入参 self 是当前的 module,is_fuse_relu 为 False。

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第11张图片
只是但对对out中的features做F.relu

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第12张图片
原始的SparseBasicBlock

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第13张图片
此时features已经是SparserTensor,不能直接F.relu。所以单独对x.features进行F.relu再封装乘SparserTensor类型。

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第14张图片
SparseBasicBlock 原本的 forward 函数

通过对比 call 函数和原始的 forward 函数,可以发现以下几个不同点:

  • call 中没有 bn 的操作,因为之后会将稀疏卷积与 bn 进行融合。
  • 如果 ReLU 没有被融合,将 features 进行 ReLU 操作后,会通过 SparseConvTensor 构建了一个新的稀疏张量来存储结果。
  • 残差结构的加法计算时,也会对加法操作的结果通过一个新的稀疏张量来储存结果。

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第15张图片

之后的操作就与之前的对一个稀疏卷积层和 bn 融合的方法相同,最后需要从这个 block 对象中删除 bn 属性。

当前模块属于 ReLU

在这里插入图片描述

仅将 ReLU 函数中的 inplace 参数更改为 False,这样返回的就是一个新的张量,不会更改原来的输入。

总结一下,这里就是将稀疏卷积网络中的所有 bn 与稀疏卷积进行融合,分为三个部分:

  • 稀疏序列的子模块包含了稀疏卷积的话,内部的稀疏卷积会与 bn 进行****融合,创建一个只包含融合后的稀疏卷积和 ReLU 的稀疏序列替换之前的序列。
  • 将 SparseBasicBlock 对象中** forward 函数进行****替换**,主要是将内部与 ReLU 和 residual 相关的部分进行了更改,然后将 bn 属性从该对象中移除。
  • 将所有的 ReLU 的 inplace 属性更****改为 False
  • 下方为初始稀疏卷积与量化后的对比
ModuleDict(
  (voxelize): Voxelization(voxel_size=[0.075, 0.075, 0.2], point_cloud_range=[-54.0, -54.0, -5.0, 54.0, 54.0, 3.0], max_num_points=10, max_voxels=(120000, 160000), deterministic=True)
  (backbone): SparseEncoder(
    (conv_input): SparseSequential(
      (0): SubMConv3d()
      (1): BatchNorm1d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (encoder_layers): SparseSequential(
      (encoder_layer1): SparseSequential(
        (0): SparseBasicBlock(
          (conv1): SubMConv3d()
          (bn1): BatchNorm1d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (conv2): SubMConv3d()
          (bn2): BatchNorm1d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): SparseBasicBlock(
          (conv1): SubMConv3d()
          (bn1): BatchNorm1d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (conv2): SubMConv3d()
          (bn2): BatchNorm1d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): SparseSequential(
          (0): SparseConv3d()
          (1): BatchNorm1d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
      (encoder_layer2): SparseSequential(
        (0): SparseBasicBlock(
          (conv1): SubMConv3d()
          (bn1): BatchNorm1d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (conv2): SubMConv3d()
          (bn2): BatchNorm1d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): SparseBasicBlock(
          (conv1): SubMConv3d()
          (bn1): BatchNorm1d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (conv2): SubMConv3d()
          (bn2): BatchNorm1d(32, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): SparseSequential(
          (0): SparseConv3d()
          (1): BatchNorm1d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
      (encoder_layer3): SparseSequential(
        (0): SparseBasicBlock(
          (conv1): SubMConv3d()
          (bn1): BatchNorm1d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (conv2): SubMConv3d()
          (bn2): BatchNorm1d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): SparseBasicBlock(
          (conv1): SubMConv3d()
          (bn1): BatchNorm1d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (conv2): SubMConv3d()
          (bn2): BatchNorm1d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (2): SparseSequential(
          (0): SparseConv3d()
          (1): BatchNorm1d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
      )
      (encoder_layer4): SparseSequential(
        (0): SparseBasicBlock(
          (conv1): SubMConv3d()
          (bn1): BatchNorm1d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (conv2): SubMConv3d()
          (bn2): BatchNorm1d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
        (1): SparseBasicBlock(
          (conv1): SubMConv3d()
          (bn1): BatchNorm1d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (conv2): SubMConv3d()
          (bn2): BatchNorm1d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
    )
    (conv_out): SparseSequential(
      (0): SparseConv3d()
      (1): BatchNorm1d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
)
SparseEncoder(
  (conv_input): SparseSequential(
    (0): SparseConvolutionQunat(
      (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
      (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
    )
    (1): ReLU()
  )
  (encoder_layers): SparseSequential(
    (encoder_layer1): SparseSequential(
      (0): SparseBasicBlock(
        (conv1): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (conv2): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (relu): ReLU()
        (quant_add): QuantAdd(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
        )
      )
      (1): SparseBasicBlock(
        (conv1): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (conv2): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (relu): ReLU()
        (quant_add): QuantAdd(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
        )
      )
      (2): SparseSequential(
        (0): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (1): ReLU()
      )
    )
    (encoder_layer2): SparseSequential(
      (0): SparseBasicBlock(
        (conv1): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (conv2): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (relu): ReLU()
        (quant_add): QuantAdd(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
        )
      )
      (1): SparseBasicBlock(
        (conv1): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (conv2): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (relu): ReLU()
        (quant_add): QuantAdd(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
        )
      )
      (2): SparseSequential(
        (0): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (1): ReLU()
      )
    )
    (encoder_layer3): SparseSequential(
      (0): SparseBasicBlock(
        (conv1): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (conv2): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (relu): ReLU()
        (quant_add): QuantAdd(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
        )
      )
      (1): SparseBasicBlock(
        (conv1): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (conv2): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (relu): ReLU()
        (quant_add): QuantAdd(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
        )
      )
      (2): SparseSequential(
        (0): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (1): ReLU()
      )
    )
    (encoder_layer4): SparseSequential(
      (0): SparseBasicBlock(
        (conv1): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (conv2): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (relu): ReLU()
        (quant_add): QuantAdd(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
        )
      )
      (1): SparseBasicBlock(
        (conv1): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (conv2): SparseConvolutionQunat(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
          (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
        )
        (relu): ReLU()
        (quant_add): QuantAdd(
          (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
        )
      )
    )
  )
  (conv_out): SparseSequential(
    (0): SparseConvolutionQunat(
      (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
      (_weight_quantizer): TensorQuantizer(8bit fake axis=4 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
    )
    (1): ReLU()
  )
)

2D 卷积和 BN 的融合

在这里插入图片描述

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第16张图片

其实这里也是通过递归的方式将所有的 Conv 与 BN 进行融合

递归时可以分成三种情况:

当前模块的子类属于 SyncBatchNorm 或不同维度的 BatchNorm 或 LazyBatchNorm

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第17张图片

  • fuse_conv_bn 递归函数

    • 56-57,俩变量置为 None
    • 遇到 convQuantConv2d,一定先走 69-71 行的分支。记录到 last_convlast_conv_name 中。
    • 再 for 循环,如果是符合是 bn,就融合。
  • 融合使用了 mmcv 的工具:https://github.com/open-mmlab/mmcv/blob/main/mmcv/cnn/utils/fuse_conv_bn.py

如果 BN 之前的不是 Conv2d 和 QuantConv2d 的话就不进行融合 (Conv1d,QuantConvTranspose2d)。

def find_conv1d(module, path):
    last_mod_name = None
    for name, submodule in module.named_children():
        if isinstance(submodule, (nn.Conv1d, QuantConvTranspose2d)):
            last_mod_name = ''.join([path, f'[{name}]']) if name.isdigit() else '.'.join([path, name])
        elif isinstance(submodule, (nn.modules.batchnorm._BatchNorm)):
            if last_mod_name:
                # print(last_mod_name)
                # print(''.join([path, f'[{name}]']) if name.isdigit() else '.'.join([path, name]))
                print(eval(last_mod_name))
                print(eval(''.join([path, f'[{name}]']) if name.isdigit() else '.'.join([path, name])), '\n')
        else:
            find_conv1d(submodule, ''.join([path, f'[{name}]']) if name.isdigit() else '.'.join([path, name]))

find_conv1d(model, 'model')

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第18张图片

如果 BN 之前的是 Conv2d 和 QuantConv2d 的话,就会将这两层进行融合。

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第19张图片

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第20张图片

这里的计算与之前将 SparseConvolutionQunat 和 BN 融合类似。

c o n v . w ∗ = c o n v . w ∗ b n . w b n . v a r + b n . e p s conv.w^* = \frac{conv.w * bn.w}{\sqrt{bn.var + bn.eps}} conv.w=bn.var+bn.eps conv.wbn.w

c o n v . b ∗ = c o n v . b − b n . m e a n b n . v a r + b n . e p s ∗ b n . w + b n . b conv.b^* = \frac{conv.b - bn.mean}{\sqrt{bn.var + bn.eps}} * bn.w + bn.b conv.b=bn.var+bn.eps conv.bbn.meanbn.w+bn.b

之后将融合后的模块设置为上一个卷积对应的值,将当前的 bn 模块设置为 nn.Identity()。之后在生成 onnx 文件时会进行优化,将 nn.Identity() 层移除。

当前模块的子类属于 Conv2d 或者 QuantConv2d

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第21张图片

更新 last_conv 和 last_conv_name 为当前的子模块和子模块的名称。如果下一个子模块为 BN 的话,就会与 last_conv 进行融合。

不是以上两种情况

MIT-BEVFusion系列七--量化3 稀疏卷积、普通卷积BN融合,fusebn_第22张图片

进行递归操作,继续遍历子模块的子模块。

1.7.4 融合后网络特点
  • model.encoders.camera.backbone.layer1[0] 为例
    • BatchNorm2d 被融合,用 I dentify 层取代原先 BatchNorm2d 的层。
Bottleneck(
  (conv1): QuantConv2d(
    64, 64, kernel_size=(1, 1), stride=(1, 1)
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
  (bn1): Identity()
  (conv2): QuantConv2d(
    64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
  (bn2): Identity()
  (conv3): QuantConv2d(
    64, 256, kernel_size=(1, 1), stride=(1, 1)
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
  (bn3): Identity()
  (relu): ReLU(inplace=True)
  (downsample): Sequential(
    (0): QuantConv2d(
      64, 256, kernel_size=(1, 1), stride=(1, 1)
      (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
      (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
    )
    (1): Identity()
  )
  (residual_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
)
Bottleneck(
  (conv1): QuantConv2d(
    64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): QuantConv2d(
    64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): QuantConv2d(
    64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
    (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
    (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
  )
  (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (downsample): Sequential(
    (0): QuantConv2d(
      64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
      (_input_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
      (_weight_quantizer): TensorQuantizer(8bit fake axis=0 amax=dynamic calibrator=MaxCalibrator scale=1.0 quant)
    )
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (residual_quantizer): TensorQuantizer(8bit fake per-tensor amax=dynamic calibrator=HistogramCalibrator scale=1.0 quant)
)

你可能感兴趣的:(自动驾驶,python,算法,人工智能,目标检测)