These are my personal study notes for the new model pruning and re-parameterization course from 手写AI, recorded purely for my own reference.
This session focuses on the re-parameterization of ACNet.
The course outline can be seen in the mind map below.
It touches on multi-branch networks such as DenseNet, VOVNet, Res2Net, and PeleeNet.
Let's start by implementing a simple network with a parallel multi-branch structure, as shown in the figure below. The example code is as follows:
import torch
import torch.nn as nn
class SimpleParallelNet(nn.Module):
def __init__(self):
super().__init__()
self.branch1 = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.branch2 = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=4, kernel_size=5, stride=1, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.branch3 = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=4, kernel_size=7, stride=1, padding=3),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.conv = nn.Conv2d(12, 32, kernel_size=3, stride=1, padding=1)
self.fc = nn.Linear(8192, 10)
def forward(self, x):
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
x = torch.cat((x1, x2, x3), dim=1)
x = self.conv(x)
x = x.view(int(x.size(0)), -1)
x = self.fc(x)
return x
if __name__ == "__main__":
model = SimpleParallelNet()
dummy_input = torch.randn((1, 3, 32, 32))
model(dummy_input)
output_path = './test.onnx'
torch.onnx.export(model, dummy_input, output_path)
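A quick side note on cost: a rough parameter count (my own addition, reusing the SimpleParallelNet class above) shows how quickly the parallel branches add up:

# Count the parameters of each branch: the 5x5 and 7x7 branches cost several
# times as many weights as the 3x3 branch while producing feature maps of the
# same shape.
net = SimpleParallelNet()
for name, branch in [('branch1 (3x3)', net.branch1),
                     ('branch2 (5x5)', net.branch2),
                     ('branch3 (7x7)', net.branch3)]:
    print(name, sum(p.numel() for p in branch.parameters()))
print('total parameters:', sum(p.numel() for p in net.parameters()))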
ACNet proposes a way of re-designing convolution kernels to make a model more robust to image flips and rotations.
During training, the original 3x3 convolution is replaced by three parallel convolutions, 1x3, 3x1 and 3x3, whose outputs are summed to form the layer's output.
This introduces no extra hyperparameters and adds no computation at inference time.
At inference, ACNet initializes the existing network with the fused kernel parameters, strengthening its feature extraction ability.
In experiments the method performs strongly; in particular, its robustness to flipped and rotated images is far higher than that of other approaches. Its defining feature is the use of asymmetric convolution kernels, hence the name Asymmetric Convolution. Because it improves robustness to flips and rotations without increasing inference-time computation, ACNet can bring accuracy gains to a wide range of networks.
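The reason the three branches can later be collapsed into one is the additivity of convolution: summing the outputs of several convolutions over the same input is equivalent to a single convolution whose kernel is the sum of the individual kernels, with the smaller kernels placed at the matching positions. A minimal sketch to verify this (my own example with made-up shapes; BatchNorm is ignored here and handled later):

import torch
import torch.nn.functional as F

x = torch.randn(1, 2, 8, 8)
k3 = torch.randn(4, 2, 3, 3)    # 3x3 kernel
k13 = torch.randn(4, 2, 1, 3)   # 1x3 kernel
k31 = torch.randn(4, 2, 3, 1)   # 3x1 kernel

# Three separate convolutions, outputs summed (training-time view).
y_sum = (F.conv2d(x, k3, padding=1)
         + F.conv2d(x, k13, padding=(0, 1))
         + F.conv2d(x, k31, padding=(1, 0)))

# Add the asymmetric kernels into the center row/column of the 3x3 kernel
# and run a single convolution (deploy-time view).
k_fused = k3.clone()
k_fused[:, :, 1:2, :] += k13
k_fused[:, :, :, 1:2] += k31
y_fused = F.conv2d(x, k_fused, padding=1)

print(torch.allclose(y_sum, y_fused, atol=1e-5))  # expected: True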
ACNet's initialization function:
import torch.nn as nn
class ACNet(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size,
stride, padding = 0, dilation = 1,
groups = 1, padding_mode = 'zeros',
deploy = False, use_affine = True):
super().__init__()
self.deploy = deploy
if self.deploy:
self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, kernel_size),
stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=True, padding_mode=padding_mode)
else:
self.square_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, kernel_size),
stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=True, padding_mode=padding_mode)
self.square_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
if padding - kernel_size // 2 >= 0:
self.crop = 0
hor_padding = [padding - kernel_size // 2, padding]
ver_padding = [padding, padding - kernel_size // 2]
else:
self.crop = kernel_size // 2 - padding
hor_padding = [0, padding]
ver_padding = [padding, 0]
self.ver_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, 1),
stride=stride, padding=ver_padding, dilation=dilation,
groups=groups, bias=True, padding_mode=padding_mode)
self.ver_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
self.hor_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, kernel_size),
stride=stride, padding=hor_padding, dilation=dilation,
groups=groups, bias=True, padding_mode=padding_mode)
self.hor_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
The use_affine argument controls whether the BatchNorm parameters $\gamma$ and $\beta$ are learnable.
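One thing to keep in mind (my own note): with affine=False the BatchNorm layer simply has no weight or bias, so the fusion code later on, which reads bn.weight and bn.bias, assumes use_affine=True:

import torch.nn as nn

bn_affine = nn.BatchNorm2d(4, affine=True)
bn_plain = nn.BatchNorm2d(4, affine=False)
print(bn_affine.weight.shape, bn_affine.bias.shape)  # learnable gamma and beta
print(bn_plain.weight, bn_plain.bias)                # None None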
ACNet's forward function:
import torch.nn as nn
class ACNet(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size,
stride, padding = 0, dilation = 1,
groups = 1, padding_mode = 'zeros',
deploy = False, use_affine = True):
super().__init__()
pass
def forward(self, input):
if self.deploy:
return self.fused_conv(input)
else:
square_outputs = self.square_conv(input)
square_outputs = self.square_bn(square_outputs)
if self.crop > 0:
ver_input = input[:, :, :, self.crop:-self.crop]
hor_input = input[:, :, self.crop:-self.crop, :]
else:
ver_input = input
hor_input = input
vertical_outputs = self.ver_conv(ver_input)
vertical_outputs = self.ver_bn(vertical_outputs)
horizontal_outputs = self.hor_conv(hor_input)
horizontal_outputs = self.hor_bn(horizontal_outputs)
result = square_outputs + vertical_outputs + horizontal_outputs
return result
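The cropping above is what keeps the three branches spatially aligned when padding < kernel_size // 2: the 3x3 convolution with padding=0 shrinks the feature map by 2 in each direction, so the 3x1 and 1x3 branches first crop one column or row from each side of their input. A quick shape check (my own sketch, using the full ACNet class given at the end of this post with the same arguments as the export example):

import torch

m = ACNet(in_channels=2, out_channels=8, kernel_size=3, stride=1, padding=0, deploy=False)
x = torch.randn(1, 2, 62, 62)
print(m.square_conv(x).shape)                        # torch.Size([1, 8, 60, 60])
print(m.ver_conv(x[:, :, :, m.crop:-m.crop]).shape)  # torch.Size([1, 8, 60, 60])
print(m.hor_conv(x[:, :, m.crop:-m.crop, :]).shape)  # torch.Size([1, 8, 60, 60])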
ACNet's conversion function, which turns the three training-time convolution kernels into a single inference-time kernel:
import torch.nn as nn
class ACNet(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size,
stride, padding = 0, dilation = 1,
groups = 1, padding_mode = 'zeros',
deploy = False, use_affine = True):
super().__init__()
pass
def forward(self, input):
pass
    def switch_to_deploy(self):
self.deploy = True
deploy_k, deploy_b = self.get_equivalent_kernel_bias()
self.fused_conv = nn.Conv2d(in_channels=self.square_conv.in_channels,
out_channels=self.square_conv.out_channels,
kernel_size=self.square_conv.kernel_size,
stride=self.square_conv.stride,
padding=self.square_conv.padding,
dilation=self.square_conv.dilation,
groups=self.square_conv.groups,
bias=True,
padding_mode=self.square_conv.padding_mode)
self.__delattr__('square_conv')
self.__delattr__('square_bn')
self.__delattr__('hor_conv')
self.__delattr__('hor_bn')
self.__delattr__('ver_conv')
self.__delattr__('ver_bn')
self.fused_conv.weight.data = deploy_k
self.fused_conv.bias.data = deploy_b
def get_equivalent_kernel_bias(self):
pass
The function that actually performs ACNet's re-parameterization:
import torch.nn as nn
class ACNet(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size,
stride, padding = 0, dilation = 1,
groups = 1, padding_mode = 'zeros',
deploy = False, use_affine = True):
super().__init__()
pass
def forward(self, input):
pass
    def switch_to_deploy(self):
pass
def get_equivalent_kernel_bias(self):
hor_k, hor_b = self.fuse_bn_tensor(self.hor_conv, self.hor_bn)
ver_k, ver_b = self.fuse_bn_tensor(self.ver_conv, self.ver_bn)
square_k, square_b = self.fuse_bn_tensor(self.square_conv, self.square_bn)
self.add_to_square_kernel(square_k, hor_k)
self.add_to_square_kernel(square_k, ver_k)
return square_k, hor_b + ver_b + square_b
def fuse_bn_tensor(self, conv, bn):
pass
def add_to_square_kernel(self, square_kernel, asym_kernel):
pass
ACNet's function for fusing a Conv2d with its BatchNorm layer. As the figure below shows, the Conv2d kernel is multiplied by $\frac{\gamma}{\sqrt{\sigma_B^2+\epsilon}}$ and the bias is shifted accordingly using $\mu_B$, $\gamma$ and $\beta$:
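Written out in full (my own restatement; $b$ denotes the convolution's own bias, which matters here because the convolutions are created with bias=True):

$$\mathrm{BN}(W * x + b) = \gamma\,\frac{W * x + b - \mu_B}{\sqrt{\sigma_B^2+\epsilon}} + \beta = \underbrace{\frac{\gamma}{\sqrt{\sigma_B^2+\epsilon}}\,W}_{W'} * x \;+\; \underbrace{\frac{\gamma\,(b-\mu_B)}{\sqrt{\sigma_B^2+\epsilon}} + \beta}_{b'}$$

so fuse_bn_tensor only needs to scale the kernel by $\gamma/\sqrt{\sigma_B^2+\epsilon}$ and compute the corrected bias $b'$, both per output channel.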
import torch.nn as nn
class ACNet(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size,
stride, padding = 0, dilation = 1,
groups = 1, padding_mode = 'zeros',
deploy = False, use_affine = True):
super().__init__()
self.deploy = deploy
pass
def forward(self, input):
pass
def get_equivalent_kernel_bias(self):
hor_k, hor_b = self.fuse_bn_tensor(self.hor_conv, self.hor_bn)
ver_k, ver_b = self.fuse_bn_tensor(self.ver_conv, self.ver_bn)
square_k, square_b = self.fuse_bn_tensor(self.square_conv, self.square_bn)
self.add_to_square_kernel(square_k, hor_k)
self.add_to_square_kernel(square_k, ver_k)
return square_k, hor_b + ver_b + square_b
def fuse_bn_tensor(self, conv, bn):
std = (bn.running_var + bn.eps).sqrt()
t = (bn.weight / std).reshape(-1, 1, 1, 1)
        return conv.weight * t, bn.bias + (conv.bias - bn.running_mean) * bn.weight / std
def add_to_square_kernel(self, square_kernel, asym_kernel):
pass
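A small numerical check of this fusion (my own sketch, on a standalone conv/BN pair rather than the ACNet module; the BN layer must be in eval mode so that the running statistics are used):

import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(2, 4, kernel_size=3, padding=1, bias=True)
bn = nn.BatchNorm2d(4)
bn.eval()
nn.init.uniform_(bn.running_mean, 0, 0.1)   # give the BN non-trivial statistics
nn.init.uniform_(bn.running_var, 0, 0.2)

std = (bn.running_var + bn.eps).sqrt()
t = (bn.weight / std).reshape(-1, 1, 1, 1)
fused_w = conv.weight * t
fused_b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / std

x = torch.randn(1, 2, 16, 16)
print(torch.allclose(bn(conv(x)), F.conv2d(x, fused_w, fused_b, padding=1), atol=1e-5))  # expected: True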
ACNet's kernel addition function, which adds the asymmetric kernels into the square kernel as shown in the figure below:
import torch.nn as nn
class ACNet(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size,
stride, padding = 0, dilation = 1,
groups = 1, padding_mode = 'zeros',
deploy = False, use_affine = True):
super().__init__()
self.deploy = deploy
pass
def forward(self, input):
pass
    def switch_to_deploy(self):
pass
def get_equivalent_kernel_bias(self):
hor_k, hor_b = self.fuse_bn_tensor(self.hor_conv, self.hor_bn)
ver_k, ver_b = self.fuse_bn_tensor(self.ver_conv, self.ver_bn)
square_k, square_b = self.fuse_bn_tensor(self.square_conv, self.square_bn)
self.add_to_square_kernel(square_k, hor_k)
self.add_to_square_kernel(square_k, ver_k)
return square_k, hor_b + ver_b + square_b
def fuse_bn_tensor(self, conv, bn):
std = (bn.running_var + bn.eps).sqrt()
t = (bn.weight / std).reshape(-1, 1, 1, 1)
        return conv.weight * t, bn.bias + (conv.bias - bn.running_mean) * bn.weight / std
def add_to_square_kernel(self, square_kernel, asym_kernel):
asym_h = asym_kernel.size(2)
asym_w = asym_kernel.size(3)
square_h = square_kernel.size(2)
        square_w = square_kernel.size(3)
square_kernel[:,
:,
square_h // 2 - asym_h // 2: square_h // 2 - asym_h // 2 + asym_h,
square_w // 2 - asym_w // 2: square_w // 2 - asym_w // 2 + asym_w] += asym_kernel
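For kernel_size=3 the slicing above picks out the middle row for the 1x3 horizontal kernel and the middle column for the 3x1 vertical kernel (square_h // 2 - asym_h // 2 evaluates to 1 for a size-1 dimension and 0 for a size-3 dimension). A tiny illustration (my own snippet):

import torch

square = torch.zeros(1, 1, 3, 3)
hor = torch.ones(1, 1, 1, 3)    # 1x3 kernel -> rows [1:2], cols [0:3]
ver = torch.ones(1, 1, 3, 1)    # 3x1 kernel -> rows [0:3], cols [1:2]
square[:, :, 1:2, 0:3] += hor
square[:, :, 0:3, 1:2] += ver
print(square[0, 0])
# tensor([[0., 1., 0.],
#         [1., 2., 1.],
#         [0., 1., 0.]])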
Exporting the ACNet network to ONNX:
import torch
import torch.nn as nn
class ACNet(nn.Module):
pass
if __name__ == '__main__':
dummy_input = torch.randn(1, 2, 62, 62)
model = ACNet(in_channels=2, out_channels=8, kernel_size=3, padding=0, stride=1, deploy=False)
model.eval()
for module in model.modules():
if isinstance(module, nn.BatchNorm2d):
nn.init.uniform_(module.running_mean, 0, 0.1)
nn.init.uniform_(module.running_var, 0, 0.2)
nn.init.uniform_(module.weight, 0, 0.3)
nn.init.uniform_(module.bias, 0, 0.4)
output = model(dummy_input)
print(model)
torch.onnx.export(model=model, args=dummy_input, f='./ACNet.onnx', verbose=False)
    model.switch_to_deploy()
deployout = model(dummy_input)
print(model)
torch.onnx.export(model=model, args=dummy_input, f='./ACNet-deploy.onnx', verbose=False)
The complete example code is as follows:
import torch
import torch.nn as nn
class ACNet(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size,
stride, padding = 0, dilation = 1,
groups = 1, padding_mode = 'zeros',
deploy = False, use_affine = True):
super().__init__()
self.deploy = deploy
if self.deploy:
self.fused_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, kernel_size),
stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=True, padding_mode=padding_mode)
else:
self.square_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, kernel_size),
stride=stride, padding=padding, dilation=dilation,
groups=groups, bias=True, padding_mode=padding_mode)
self.square_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
if padding - kernel_size // 2 >= 0:
self.crop = 0
hor_padding = [padding - kernel_size // 2, padding]
ver_padding = [padding, padding - kernel_size // 2]
else:
self.crop = kernel_size // 2 - padding
hor_padding = [0, padding]
ver_padding = [padding, 0]
self.ver_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(kernel_size, 1),
stride=stride, padding=ver_padding, dilation=dilation,
groups=groups, bias=True, padding_mode=padding_mode)
self.ver_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
self.hor_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, kernel_size),
stride=stride, padding=hor_padding, dilation=dilation,
groups=groups, bias=True, padding_mode=padding_mode)
self.hor_bn = nn.BatchNorm2d(num_features=out_channels, affine=use_affine)
def forward(self, input):
if self.deploy:
return self.fused_conv(input)
else:
square_outputs = self.square_conv(input)
square_outputs = self.square_bn(square_outputs)
if self.crop > 0:
ver_input = input[:, :, :, self.crop:-self.crop]
hor_input = input[:, :, self.crop:-self.crop, :]
else:
ver_input = input
hor_input = input
vertical_outputs = self.ver_conv(ver_input)
vertical_outputs = self.ver_bn(vertical_outputs)
horizontal_outputs = self.hor_conv(hor_input)
horizontal_outputs = self.hor_bn(horizontal_outputs)
result = square_outputs + vertical_outputs + horizontal_outputs
return result
    def switch_to_deploy(self):
self.deploy = True
deploy_k, deploy_b = self.get_equivalent_kernel_bias()
self.fused_conv = nn.Conv2d(in_channels=self.square_conv.in_channels,
out_channels=self.square_conv.out_channels,
kernel_size=self.square_conv.kernel_size,
stride=self.square_conv.stride,
padding=self.square_conv.padding,
dilation=self.square_conv.dilation,
groups=self.square_conv.groups,
bias=True,
padding_mode=self.square_conv.padding_mode)
self.__delattr__('square_conv')
self.__delattr__('square_bn')
self.__delattr__('hor_conv')
self.__delattr__('hor_bn')
self.__delattr__('ver_conv')
self.__delattr__('ver_bn')
self.fused_conv.weight.data = deploy_k
self.fused_conv.bias.data = deploy_b
def get_equivalent_kernel_bias(self):
hor_k, hor_b = self.fuse_bn_tensor(self.hor_conv, self.hor_bn)
ver_k, ver_b = self.fuse_bn_tensor(self.ver_conv, self.ver_bn)
square_k, square_b = self.fuse_bn_tensor(self.square_conv, self.square_bn)
self.add_to_square_kernel(square_k, hor_k)
self.add_to_square_kernel(square_k, ver_k)
return square_k, hor_b + ver_b + square_b
def fuse_bn_tensor(self, conv, bn):
std = (bn.running_var + bn.eps).sqrt()
t = (bn.weight / std).reshape(-1, 1, 1, 1)
        # also fold the conv's own bias into the shift, since the convs here are created with bias=True
        return conv.weight * t, bn.bias + (conv.bias - bn.running_mean) * bn.weight / std
def add_to_square_kernel(self, square_kernel, asym_kernel):
asym_h = asym_kernel.size(2)
asym_w = asym_kernel.size(3)
square_h = square_kernel.size(2)
        square_w = square_kernel.size(3)
square_kernel[:,
:,
square_h // 2 - asym_h // 2: square_h // 2 - asym_h // 2 + asym_h,
square_w // 2 - asym_w // 2: square_w // 2 - asym_w // 2 + asym_w] += asym_kernel
if __name__ == '__main__':
dummy_input = torch.randn(1, 2, 62, 62)
model = ACNet(in_channels=2, out_channels=8, kernel_size=3, padding=0, stride=1, deploy=False)
model.eval()
for module in model.modules():
if isinstance(module, nn.BatchNorm2d):
nn.init.uniform_(module.running_mean, 0, 0.1)
nn.init.uniform_(module.running_var, 0, 0.2)
nn.init.uniform_(module.weight, 0, 0.3)
nn.init.uniform_(module.bias, 0, 0.4)
output = model(dummy_input)
print(model)
torch.onnx.export(model=model, args=dummy_input, f='./ACNet.onnx', verbose=False)
    model.switch_to_deploy()
deployout = model(dummy_input)
print(model)
torch.onnx.export(model=model, args=dummy_input, f='./ACNet-deploy.onnx', verbose=False)
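To confirm that switch_to_deploy is lossless, the two outputs computed above can be compared directly; a couple of extra lines appended to the main block (my own addition) are enough:

    # Training-time and deploy-time outputs should match up to floating-point error.
    print(torch.allclose(output, deployout, atol=1e-5))  # expected: True
    print((output - deployout).abs().max())              # should be tiny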
In ACNet's constructor we specify the input channels, output channels, kernel size, stride, padding, dilation, groups, padding_mode and so on. Setting deploy=True selects deployment mode, i.e. the fused convolution is used instead of the three separate convolutions. (from ChatGPT)
If deploy is False, the submodules square_conv, square_bn, ver_conv, ver_bn, hor_conv and hor_bn are constructed. In the forward function the input goes through the square convolution and, in parallel, through the vertical and horizontal convolutions, and the three results are summed.
For re-parameterization, every convolution must first be fused with its BatchNorm layer. In ACNet, the fused kernels and biases of the horizontal and vertical convolutions are then added to those of the square convolution.
In get_equivalent_kernel_bias we call fuse_bn_tensor to fuse each convolution with its BatchNorm layer: the convolution weights are multiplied by the BatchNorm scale factor and the bias is corrected accordingly.
In add_to_square_kernel, the horizontal and vertical kernels are added into the center of the square kernel.
A comparison of the exported ONNX graphs is shown below:
This session started from the parallel multi-branch structure, discussed its uses and the problem of its large parameter count, introduced the idea of re-parameterization, and walked through ACNet's re-parameterization, which replaces the original three convolutions with a single convolution at inference time. Looking forward to the DBB re-parameterization next time.