模型压缩工具Distiller-剪枝

1.distiller剪枝模块的使用

(1)distiller自带剪枝实例测试

        distiller自带一些测试实例如ResNet56+cifar-10,下面是对ResNet56+cifar-10的测试:

  •  测试前准备

  • yaml文件(注意:这里的yaml文件是coder配置好的,具体到自己的模型需要先对自己的model进行一次Sparsity Analysis,然后自己配置该文件) 在剪枝时所用到的yaml文件作用主要是配置了一些剪枝所需要的必要信息,比如下面ResNet56所需要用的yaml配置文件(路径:distiller/examples/pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml):
version: 1                                        # 版本
pruners:   
  filter_pruner_60:                               # 后面的60表示剪掉60%的Filters,如[16, 16, 3, 3]剪掉之后就是[7, 16, 3, 3]
    class: 'L1RankedStructureParameterPruner' # 表示所使用的算法,这里使用L1Rank
 group_type: Filters                              # 表示剪切类型,一般两种Filters/Channel
    desired_sparsity: 0.6                         # 剪掉60%的Filters
    weights: [                                    # 下面是一些具体的需要剪切的权值
      module.layer1.0.conv1.weight,
      module.layer1.1.conv1.weight,
      module.layer1.2.conv1.weight,
      module.layer1.3.conv1.weight,
      module.layer1.4.conv1.weight,
      module.layer1.5.conv1.weight,
      module.layer1.6.conv1.weight,
      module.layer1.7.conv1.weight,
      module.layer1.8.conv1.weight]

  filter_pruner_50:                                # 同上
    class: 'L1RankedStructureParameterPruner'
 group_type: Filters
    desired_sparsity: 0.5
    weights: [
      module.layer2.1.conv1.weight,
      module.layer2.2.conv1.weight,
      module.layer2.3.conv1.weight,
      module.layer2.4.conv1.weight,
      module.layer2.6.conv1.weight,
      module.layer2.7.conv1.weight]

  filter_pruner_10:                                 # 同上
    class: 'L1RankedStructureParameterPruner'
 group_type: Filters
    desired_sparsity: 0.1
    weights: [module.layer3.1.conv1.weight]

  filter_pruner_30:                                 # 同上
    class: 'L1RankedStructureParameterPruner'
 group_type: Filters
    desired_sparsity: 0.3
    weights: [
      module.layer3.2.conv1.weight,
      module.layer3.3.conv1.weight,
      module.layer3.5.conv1.weight,
      module.layer3.6.conv1.weight,
      module.layer3.7.conv1.weight,
      module.layer3.8.conv1.weight]


extensions:
  net_thinner:
      class: 'FilterRemover'
 thinning_func_str: remove_filters
      arch: 'resnet56_cifar' # 使用的网络
 dataset: 'cifar10' # 数据集

lr_schedulers:
   exp_finetuning_lr:
     class: ExponentialLR
     gamma: 0.95


policies:
  - pruner:
      instance_name: filter_pruner_60
    epochs: [0]

  - pruner:
      instance_name: filter_pruner_50
    epochs: [0]

  - pruner:
      instance_name: filter_pruner_30
    epochs: [0]

  - pruner:
      instance_name: filter_pruner_10
    epochs: [0]

  - extension:
      instance_name: net_thinner
    epochs: [0]

  - lr_scheduler:
      instance_name: exp_finetuning_lr
    starting_epoch: 10
    ending_epoch: 300
    frequency: 1

 

  • 准备ResNet-56需要的模型文件,可下载:

    https://s3-us-west-1.amazonaws.com/nndistiller/pruning_filters_for_efficient_convnets/checkpoint.resnet56_cifar_baseline.pth.tar
  • 剪枝

        找到compress_classifier.py文件,如下:

        模型压缩工具Distiller-剪枝_第1张图片

  

 $python3 compress_classifier.py -a=resnet56_cifar -p=50 ../../../data.cifar10 --epochs=70 --lr=0.1 --compress=../pruning_filters_for_efficient_convnets/resnet56_cifar_filter_rank.yaml --resume-from=checkpoint.resnet56_cifar_baseline.pth.tar --reset-optimizer --vs=0

             参数解释: -a 表示模型名称(这里是工具自带的模型名称,其他的如resnet32_cifar, resnet44_cifar, resnet56_cifar等等 cifar的模型代码文件位于distiller/models/cifar10/resnet_cifar.py)

              -p表示每隔多少打印一次

              ../../../data.cifar10是数据集路径

              --epochs 表示剪枝过后继续训练次数

              --compress 表示所用的‘策略’(compress_scheduler),一般是yaml文件的路径

              --resume-from 表示保存的模型的路径

              --reset-optimizer 如果设置此参数,那么start_epoch=0,将optimizer重置为SGD, 学习绿设置为传入的学习率

              --vs validation-split

              具体的其他参数参看distiller/apputils/image_classifier.py文件和distiller/quantization/range_linear.py文件以及github上参数解释。

              运行时会对模型进行剪枝,然后在测试集上测试,打印出top1和top5以及loss,运行结束后量化模型会保存在logs下。

 

(2)distiller对自己的模型剪枝

  • 具体流程:1.Sparsity Analysis 分析各层weight的sensitivity,即对模型各个部分的稀疏性和可以pruning的程度有个了解。 2. Yaml file create 创建一个属于自己模型的配置文件,里面各层的稀疏度是由第一阶段分析出来的sensitivity而来的 3.Thinning 对网络进行真正的剪枝。
  • 本人已经成功使用L1Rank算法在resnet-34(cifar-10和人脸),MobileFaceNet和行人重识别模型MGN上成功进行了剪枝。

你可能感兴趣的:(模型压缩)