鉴别力感知的通道剪枝——Discrimination-aware Channel Pruning

"Discrimination-aware Channel Pruning for Deep Neural Networks"这篇文章首先认为通道剪枝能够确保剪枝后模型与现有深度学习框架兼容,避免非规整的稀疏运算。其次基于训练的通道剪枝策略,需要在训练过程中施加稀疏约束或正则化约束,通常消耗较大的训练时间。另外基于最小化输出特征重建误差、layer-by-layer方式的剪枝策略(channel selection),容易忽视通道特征的鉴别力。因此文章提出了鉴别力感知的通道剪枝策略,即Discrimination-aware Channel Pruning(DCP),分别在fine-tune与剪枝(channel selection)阶段引入鉴别力感知的辅助loss,能够增强剪枝后所保留特征的鉴别能力。

鉴别力感知的通道剪枝——Discrimination-aware Channel Pruning_第1张图片

整体算法框架如上图所示:首先在第p阶段,结合Lp层的辅助损失L(p,s)与剪枝网络最终的分类损失Lf,更新第p-1阶段剪枝得到的模型,以恢复整体精度与特征的鉴别力;然后执行第p阶段的剪枝操作,以输出特征重建误差和辅助损失L(p,s)的联合loss作为优化目标,采用贪心策略完成channel selection,进而完成中的各个layer的剪枝。如此反复迭代,直至完成整体网络的剪枝目标。

类似于GoogleNet的训练,引入Lp层的辅助loss L(p,s),能够提升剪枝后所保留的通道对最终分类的贡献,即提升通道特征的鉴别能力(Discrimination power),因此称之为鉴别力感知的辅助loss(Discrimination-aware loss)。首先Lp层输出特征Op经过BN-ReLU-AveragePooling操作,降维成F(p,i)(表示batch索引,p表示通道索引,通道数为np),然后通过Softmax分类层输出预测得分(输出长度为m,表示分类数目)。记表示辅助分类层的权重矩阵,则L(p,s)可计算为:

在剪枝阶段,需要最小化的重建误差,表示为剪枝后模型输出特征与baseline特征之间的MSE(mean squared error):

为了确保剪枝后保留的通道具备一定的鉴别力,在剪枝阶段文章以重建误差与辅助loss L(p,s)相结合的联合loss作为优化目标,并提出了带不等式约束的优化问题

附加loss的加权系数设置在1.0,具体分析见Ablation实验结论。并且式中表示剪枝后需要保留的通道数(沿输入通道维度,选择最重要的输入特征):

鉴别力感知的通道剪枝——Discrimination-aware Channel Pruning_第2张图片

DCP(Discrimination-aware channel pruning)算法的具体实施如下所示,总共划分为P+1个阶段,每个阶段包括结合辅助loss的fine-tune、与结合辅助loss的channel selection两个子阶段:

鉴别力感知的通道剪枝——Discrimination-aware Channel Pruning_第3张图片

整体算法描述:在第p阶段,针对Lp层构建辅助loss L(s,p),并学习辅助分类层参数;然后结合L(s,p)与Lf更新剪枝后网络和辅助层的参数,起到fine-tune第p-1阶段获得的剪枝后模型(恢复整体精度)、以及提升特征鉴别力的作用;进而针对,逐层实施channel selection,完成第p阶段的剪枝操作。

Channel selection描述:文章采用贪心策略(greedy method)求解剪枝优化问题,即最小化重建误差与辅助loss的联合目标L。首先设置一个空集A,再按迭代方式选择输入特征中最重要的通道索引(对网络鉴别力存在实际贡献)加入到集合A中。在每次迭代时,首先计算联合目标LWj的梯度(Wj 表示第输入通道的卷积核参数),然后选择对梯度响应最大的通道加入到集合A中。由于联合loss中包含了鉴别力感知的辅助loss,因此所选择的输入特征通道具有较强的鉴别力。每往A中加入一个通道索引,便更新一次集合A中的参数,并驱使A补集中的参数为零:

Channel selection的终止条件(stopping conditions):给定一个预先设定的剪枝率,可以终止循环。但当经验数值(剪枝率)难以确定时,可以按照优化目标的绝对差值变化相对于迭代开始时的比值,自适应、动态地确定需要裁剪的通道数,也就是当比值很小、且低于容忍度时(容忍度越小,所保留的通道数越多,剪枝后模型性能越好)退出循环、完成channel selection:

最后,文章实验部分对比了DCP、DCP-Adapt、ThiNet、Channel Pruning(CP)、Network Slimming、WM、WM+以及Random DCP,评估数据集为Cifar10、ILSVRC-12以及LFW。其中DCP在ILSVRC-12上的表现优于ThiNet、CP、WM与WM+,不同剪枝率条件下,精度变化如下:

鉴别力感知的通道剪枝——Discrimination-aware Channel Pruning_第4张图片

通过观察Feature Map可知,所保留的特征通道具备更强的鉴别力:

鉴别力感知的通道剪枝——Discrimination-aware Channel Pruning_第5张图片

 

Paper地址:https://arxiv.org/abs/1810.11809

Tecent PocketFlow地址PocketFlow/learners/discr_channel_pruning/):https://github.com/Tencent/PocketFlow

 

你可能感兴趣的:(深度学习,模型压缩,优化加速)