利用L1范数的CNN模型剪枝

文章目录

  • 1.原理
  • 2.修改模型
  • 3.数据集
  • 4.代码实现
    • 4.1.正常训练
    • 4.2.稀疏训练
    • 4.3.剪枝
    • 4.4.微调
  • 参考文献

1.原理

缩放因子和稀疏性引起的惩罚。我们的想法是为每个通道引入一个缩放因子 γ,它与该通道的输出相乘。然后我们联合训练网络权重和这些缩放因子,并对后者进行稀疏正则化。最后,我们用小因子修剪那些通道,并对修剪后的网络进行微调。具体来说,我们方法的训练目标由下式给出
L = ∑ ( x , y ) l ( f ( x , W ) , y ) + λ ∑ γ ∈ Γ g ( γ ) — — — — ( 1 ) L=\sum_{(x,y)}l(f(x,W),y)+λ\sum_{\gamma\in\Gamma}g(\gamma)————(1) L=(x,y)l(f(x,W),y)+λγΓg(γ)(1)
其中(x,y)表示训练输入和目标,W 表示可训练权重,第一个求和项对应于 CNN 的正常训练损失,g(·)是稀疏性引起的尺度因子惩罚,λ 平衡这两个项。在我们的实验中,我们选择 g(s)=| s |,这被称为 L1 范数,广泛用于实现稀疏性。

​ 由于修剪通道本质上对应于删除该通道的所有传入和传出连接,因此我们可以直接获得一个狭窄的网络(见图 1),而无需借助任何特殊的稀疏计算包。缩放因子充当通道选择的代表。由于它们与网络权重联合优化,网络可以自动识别无关紧要的通道,从而安全地删除这些通道,而不会很大地影响泛化性能。
利用L1范数的CNN模型剪枝_第1张图片

图 1:我们将缩放因子(从批量归一化层复用)与卷积层中的每个通道相关联。在训练期间对这些缩放因子施加稀疏正则化以自动识别不重要的通道。具有小比例因子值(橙色)的通道将被修剪(左侧)。修剪后,我们获得紧凑模型(右侧),然后对其进行微调以达到与正常训练的完整网络相当(甚至更高)的精度。

利用 BN 层中的缩放因子。批量归一化已被大多数现代 CNN 用作实现快速收敛和更好泛化性能的标准方法。 BN 对激活进行归一化的方式促使我们设计一种简单有效的方法来合并通道缩放因子。特别是,BN 层使用小批量统计对内部激活进行归一化。设 zinzout 为一个 BN 层的输入和输出,B 表示当前的 minibatchBN 层进行如下变换:
z ^ = z i n − μ B ( σ B 2 + ϵ ) ; z o u t = γ z ^ + β — — — — ( 2 ) \hat{z} = \frac{z_{in}-\mu_{\mathcal{B}}}{\sqrt(\sigma^{2}_{\mathcal{B}}+\epsilon)};z_{out}=\gamma\hat{z}+\beta————(2) z^=( σB2+ϵ)zinμB;zout=γz^+β(2)
其中 μ B μ_{B} μB σ B σ_{B} σB B \mathcal{B} B 上输入激活的均值和标准差值,γβ 是可训练的仿射变换参数(尺度和位移),它提供了将归一化激活线性转换回任何尺度的可能性。

在归一化后会进行线性变换,那么当系数 γ 很小时候,对应的激活(Zout)会相应很小。这些响应很小的输出可以裁剪掉,这样就实现了 BN 层的通道剪枝。通过在损失函数中添加 γL1 正则约束,可以实现 γ 的稀疏化。公式(1)等号右边第一项是原始的损失函数,第二项是约束,其中 g(s) = |s|,λ 是正则系数,根据数据集调整实际训练的时候,就是在优化损失函数最小,依据梯度下降算法:

​ ′=∑′+∑′()=∑′+∑||′=∑′+∑∗()

所以只需要在反向传播时,在 BN 层权重乘以权重的符号函数输出和系数即可。
利用L1范数的CNN模型剪枝_第2张图片

图 2:网络瘦身过程流程图。虚线用于多通道/迭代方案。

​ 通常的做法是在卷积层之后插入一个 BN 层,并带有通道级缩放/移位参数。因此,我们可以直接利用 BN 层中的 γ 参数作为网络瘦身所需的缩放因子。它的巨大优势在于不会给网络增加开销。事实上,这可能也是我们学习有意义的通道剪枝缩放因子的最有效方式。 1)如果我们在没有 BN 层的 CNN 中添加缩放层,缩放因子的值对于评估通道的重要性没有意义,因为卷积层和缩放层都是线性变换。可以通过减小缩放因子值同时放大卷积层中的权重来获得相同的结果。 2)如果我们在 BN 层之前插入一个缩放层,缩放层的缩放效果将被 BN 中的归一化过程完全抵消。 3)如果我们在 BN 层之后插入缩放层,则每个通道有两个连续的缩放因子。

通道修剪和微调。在通道级稀疏诱导正则化下训练后,我们获得了一个模型,其中许多缩放因子接近于零(见图 1)。然后我们可以通过删除所有传入和传出连接以及相应的权重来修剪具有接近零缩放因子的通道。我们使用跨所有层的全局阈值修剪通道,该阈值定义为所有缩放因子值的某个百分位数。例如,我们通过选择百分比阈值为 70% 来修剪具有较低缩放因子的 70% 通道。通过这样做,我们获得了一个更紧凑的网络,具有更少的参数和运行时内存,以及更少的计算操作。

​ 当修剪率很高时,修剪可能会暂时导致一些精度损失。但这可以在很大程度上通过修剪后的网络上的微调过程得到补偿。在我们的实验中,经过微调的窄网络在很多情况下甚至可以达到比原始未剪枝网络更高的精度。

多程方案。我们还可以将所提出的方法从单程学习方案(使用稀疏正则化、修剪和微调进行训练)扩展到多程方案。具体来说,网络瘦身过程会导致网络变窄,我们可以再次应用整个训练过程来学习更紧凑的模型。这由图 2 中的虚线说明。实验结果表明,这种多通道方案可以在压缩率方面产生更好的结果。

处理跨层连接和预激活结构。上面介绍的网络瘦身过程可以直接应用于大多数普通的 CNN 架构,例如 AlexNetVGGNet。虽然将其应用于具有跨层连接和预激活设计(例如 ResNetDenseNet)的现代网络时需要进行一些调整。对于这些网络,一层的输出可以作为后续多个层的输入,其中在卷积层之前放置一个 BN 层。在这种情况下,稀疏性是在层的传入端实现的,即该层有选择地使用它接收到的信道子集。为了在测试时获得参数和计算节省,我们需要放置一个通道选择层来屏蔽我们已经确定的无关紧要的通道。

2.修改模型

​ 根据训练的数据集的需要,修改 yaml 文件中的 nc 即可。

​ 例如需要训练 coco 数据集中的车类目标检测,就将 coco 中的车类图像选出来(提取方法),[‘bicycle’, ‘car’, ‘motorcycle’, ‘bus’, truck] 一共 5 类目标。需要修剪的模型是 yolov5s ,就将 yolov5s.yaml 中的 nc 改为 5。

3.数据集

​ 从已有数据集中提取出自己想要的图像和 label,或者制作自己需要的数据集。我是从 coco 数据集中提取想要的图像,可以参考这里。

4.代码实现

4.1.正常训练

​ 训练 yolov5 模型作为基准,用 yolov5 的源码和权重文件。

4.2.稀疏训练

​ 反向传播时,在 BN 层的权重乘以权重的符号函数和学习稀疏系数。因为对 BN 层进行剪枝,所有的 BN 层都接在卷积层之后,这里选择没有 shortcut 的层进行剪枝。

​ 在 yolov5 项目原有的 train.py 文件中加入下面内容:

# Backward
loss.backward()
# scaler.scale(loss).backward()
# # ============================= sparsity training ========================== #
srtmp = opt.sr*(1 - 0.9*epoch/epochs)  # L1系数逐渐减小
if opt.st:
    ignore_bn_list = []
    for k, m in model.named_modules():
        if isinstance(m, Bottleneck):  # 有shortcut的Bottleneck层不剪枝
            if m.add:
                ignore_bn_list.append(k.rsplit(".", 2)[0] + ".cv1.bn")
                ignore_bn_list.append(k + '.cv1.bn')
                ignore_bn_list.append(k + '.cv2.bn')
                if isinstance(m, nn.BatchNorm2d) and (k not in ignore_bn_list):
                    m.weight.grad.data.add_(srtmp * torch.sign(m.weight.data))  # L1
                    m.bias.grad.data.add_(opt.sr*10 * torch.sign(m.bias.data))  # L1
# # ============================= sparsity training ========================== #

optimizer.step()
# scaler.step(optimizer)  # optimizer.step
# scaler.update()
optimizer.zero_grad()

4.3.剪枝

​ 将 BN 层前的卷积层对应通道的卷积核裁剪掉,BN 层后对应的特征图裁剪掉。

(1)将稀疏化的参数升序排列,根据修剪百分比设置参数的阈值

# =========================================== prune model ====================================#
model_list = {}
ignore_bn_list = []

for i, layer in model.named_modules():
    if isinstance(layer, Bottleneck):
        if layer.add:
            ignore_bn_list.append(i.rsplit(".",2)[0]+".cv1.bn")
            ignore_bn_list.append(i + '.cv1.bn')
            ignore_bn_list.append(i + '.cv2.bn')
    if isinstance(layer, nn.BatchNorm2d):
        if i not in ignore_bn_list:
            model_list[i] = layer
            print(i, layer)
model_list = {k:v for k,v in model_list.items() if k not in ignore_bn_list}

bn_weights = gather_bn_weights(model_list)
sorted_bn = torch.sort(bn_weights)[0]

# 避免剪掉所有channel的最高阈值(每个BN层的gamma的最大值的最小值即为阈值上限)
highest_thre = []
for bnlayer in model_list.values():
    highest_thre.append(bnlayer.weight.data.abs().max().item())
highest_thre = min(highest_thre)
# 找到highest_thre对应的下标对应的百分比
percent_limit = (sorted_bn == highest_thre).nonzero()[0, 0].item() / len(bn_weights)

print(f'Suggested Gamma threshold should be less than {highest_thre:.4f}.')
print(f'The corresponding prune ratio is {percent_limit:.3f}, but you can set higher.')
assert opt.percent < percent_limit, f"Prune ratio should less than {percent_limit}, otherwise it may cause error!!!"

# model_copy = deepcopy(model)
thre_index = int(len(sorted_bn) * opt.percent)
thre = sorted_bn[thre_index]
print(f'Gamma value that less than {thre:.4f} are set to zero!')
print("=" * 94)
print(f"|\t{'layer name':<25}{'|':<10}{'origin channels':<20}{'|':<10}{'remaining channels':<20}|")
remain_num = 0
modelstate = model.state_dict()

(2)根据指定的修剪阈值,设定掩码,将较小的参数置 0,其余不变

# ============================================================================== #
maskbndict = {}
for bnname, bnlayer in model.named_modules():
    if isinstance(bnlayer, nn.BatchNorm2d):
        bn_module = bnlayer
        mask = obtain_bn_mask(bn_module, thre)
        if bnname in ignore_bn_list:
            mask = torch.ones(bnlayer.weight.data.size()).cuda()
        maskbndict[bnname] = mask
        # print("mask:",mask)
        remain_num += int(mask.sum())
        bn_module.weight.data.mul_(mask)
        bn_module.bias.data.mul_(mask)
        # print("bn_module:", bn_module.bias)
        print(f"|\t{bnname:<25}{'|':<10}{bn_module.weight.data.size()[0]:<20}{'|':<10}{int(mask.sum()):<20}|")
print("=" * 94)
# print(maskbndict.keys())

pruned_model = ModelPruned(maskbndict=maskbndict, cfg=pruned_yaml, ch=3).cuda()
# Compatibility updates
for m in pruned_model.modules():
    if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model]:
        m.inplace = True  # pytorch 1.7.0 compatibility
    elif type(m) is Conv:
        m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility

from_to_map = pruned_model.from_to_map
pruned_model_state = pruned_model.state_dict()
assert pruned_model_state.keys() == modelstate.keys()
# ======================================================================================= #

(3)修剪参数,只保留非 0 参数,然后保存 pt 文件,代码太长不贴了

4.4.微调

​ 与 yolov5 正常训练基本一致,把权重文件换成修剪后的即可,最终 [email protected] 是0.64、权重文件大小是 5.4MB,而正常训练的 [email protected] 是0.636、权重文件大小是 14.4MB。
利用L1范数的CNN模型剪枝_第3张图片

参考文献

(1)https://arxiv.org/abs/1708.06519

(2)https://blog.csdn.net/IEEE_FELLOW/article/details/117236025

你可能感兴趣的:(cnn,剪枝,深度学习,目标检测)