Soft Filter Pruning (SFP)——允许更新Pruned Filters的Channel Pruning

 "Soft Filter Pruning for Accelerating Deep Convolutional Neural Networks"这篇文章首先强调了结构稀疏的优势,基于结构稀疏的channel pruning不需要特定存储格式和算法库的支持,能够充分利用成熟算法库或框架以运行剪枝后模型。文章同时提到传统的"hard filter pruning"依赖于预训练模型,并且直接删除pruned filters,结果导致随着模型容量的减少,推理精度急剧下降,且需要额外的、相对耗时的fine-tuning过程以恢复损失的精度。并且,直接删除的filters不再接受参数更新,显得简单粗糙,通常为了获得较大的剪枝率,需要多次迭代地实施剪枝操作。

Soft Filter Pruning (SFP)——允许更新Pruned Filters的Channel Pruning_第1张图片

如上图所示,文章为此提出了"soft filter pruning (SFP)"策略,允许模型从随机初始化开始(从预训练模型开始能获得更好的效果),并在每个epoch训练开始之前,将具有较小L2-norm的filters置零,然后更新所有filters(包括未剪枝和已剪枝filters),最终模型收敛以后再把一些不重要的filters(zero-filters)裁剪掉,从而获得模型容量较高、推理精度较高的正则化、剪枝结果。显然该策略类似于DSD(Dense-Sparsity-Dense)的正则化、剪枝策略,能够充分利用每个权重连接(无论是未剪枝和已剪枝的连接)的记忆作用,达到理想的正则化效果,并驱使既定比例的权重系数趋于稀疏化。

Soft Filter Pruning (SFP)——允许更新Pruned Filters的Channel Pruning_第2张图片

 

Soft Filter Pruning (SFP)——允许更新Pruned Filters的Channel Pruning_第3张图片

Soft Filter Pruning (SFP)策略如上图所示,主要分为四个步骤:1)filter selection:采用L2-norm(作为importance衡量准则)以及预先定义的剪枝率Pi,选择出一些不重要的filters;2)filter pruning:在每个epoch训练开始之前,在全局层面将不重要的filters置零,并允许置零的filters在当前epoch训练期间接受参数更新(soft-manner,不同于greedy selection),从而更好地平衡每个filter的贡献;3)reconstruction:通过反向传播更新所有filters,能够让pruned model按照与原始模型相同的容量接受参数更新。显然,置零filters对应的前向输出将趋于零,导致这些filters的梯度也趋于零,对应的参数更新幅度也会变得很小,而重要filters仍然接受正常的参数更新,从而达到理想的正则化效果;4)obtaining compact model:最终正则化、收敛以后,通过裁减掉zero filters可以获得结构紧凑的网络模型,同时达到理想的压缩与加速效果;

实验部分,文章在Cifar10、ImageNet2012数据集上对Resnet做了测试,获得了理想的剪枝效果,具体结果见文章。

总的来说,SFP剪枝策略首先通过正则化,降低了模型的过拟合风险,获得了饱含一定稀疏度的待剪枝模型,非常适合如Resnet、ResNext和VGG等含BN层或不含BN的CNN网络的结构性剪枝。另外,模型正则化之后,3D filters的重要性衡量准则可替换为Taylor Expansion Criteria等,或许能获得更好的剪枝效果,如下所示(Taylor Expansion Criteria与L2 norm):

for k, (n, m) in enumerate(model.named_modules()):
        # compute importance rank
        if isinstance(m, nn.Conv2d) and (k <= ML):
            if method == 'taylor':
                rank_temp_avg = torch.zeros(m.weight.data.shape[0]).float()
                for i in range(args.iters):
                    activation, grad = acts[i][index], grads[i][index]
                    rank_temp = torch.sum((activation * grad), dim = 0, keepdim=True).\
        				        sum(dim=2, keepdim=True).sum(dim=3, keepdim=True)[0, :, 0, 0].data
                    #rank_temp = torch.abs(rank_temp)
                    rank_temp = torch.abs(rank_temp / float(activation.size(0) * activation.size(2) * activation.size(3)))
                    rank_temp = rank_temp / torch.sqrt(torch.sum(rank_temp * rank_temp))
                    rank_temp_avg += rank_temp
                rank_temp_avg /= args.iters
                rank_temp = rank_temp_avg.cuda()
            elif method == 'l2':
                weight_torch = m.weight.data
                weight_vec = weight_torch.view(weight_torch.size()[0], -1)
                rank_temp = torch.norm(weight_vec, 2, 1)
            rank_dict[k] = rank_temp
            rank_list.append(rank_temp)
            total_pruned += rank_temp.shape[0]
            index += 1

论文地址:https://arxiv.org/abs/1808.06866

GitHub:https://github.com/he-y/soft-filter-pruning

你可能感兴趣的:(深度学习,模型压缩)