Distiller:神经网络压缩研究框架

Distiller是由Intel AI Lab维护的基于PyTorch的开源神经网络压缩框架。主要包括:

  • 用于集成剪枝(pruning),正则化(regularization)和量化(quantization )算法的框架。
  • 一套用于分析和评估压缩性能的工具。
  • 现有技术压缩算法的示例实现。

这算是目前我发现的最完整的压缩框架了,比较适合科研工作。下面简单说一下安装和使用。

NervanaSystems/distiller
Distiller Documentation

安装

创建虚拟环境

Distiller是基于python的开源框架,为了不与其他工作冲突,最好提前创建一个虚拟环境。官方给的教程是使用virtualenv的,这个功能简单一些,这里有安装教程。目前流行的python包管理是conda(我用的也是),用的顺手就行。之后activate到创建的环境。

注意: - Distiller 仅在 Ubuntu 16.04 LTS中 Python 3.5环境下测试,并且默认使用了GPU,如果不使用GPU代码可能需要微小的调整。

克隆仓库并安装依赖

$ git clone https://github.com/NervanaSystems/distiller.git

Clone Distiller的仓库。

$ pip3 install -r requirements.txt

安装所需要的依赖库。
这样Distiller就安装好了。

数据集

分类数据集的整理需要按照文档给定的方式(这里)。

distiller
  examples
    classifier_compression
data.imagenet/
    train/
    val/
data.cifar10/
    cifar-10-batches-py/
        batches.meta
        data_batch_1
        data_batch_2
        data_batch_3
        data_batch_4
        data_batch_5
        readme.html
        test_batch

当然还是比较常规的,建立软连接即可。

使用

仓库提供了大量的实例(distiller/examples/),本文以distiller/examples/classifier_compression/compress_classifier.py为例。

命令行参数

这个文件包含了一整套压缩框架,参数帮助可以调用以下命令:

$ python3 compress_classifier.py --help

比如我们输入以下指令:

$ time python3 compress_classifier.py -a alexnet --lr 0.005 -p 50 ../../../data.imagenet -j 44 --epochs 90 --pretrained --compress=../sensitivity-pruning/alexnet.schedule_sensitivity.yaml

Parameters:
 +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
 |    | Name                      | Shape            |   NNZ (dense) |   NNZ (sparse) |   Cols (%) |   Rows (%) |   Ch (%) |   2D (%) |   3D (%) |   Fine (%) |     Std |     Mean |   Abs-Mean |
 |----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------|
 |  0 | features.module.0.weight  | (64, 3, 11, 11)  |         23232 |          13411 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   42.27359 | 0.14391 | -0.00002 |    0.08805 |
 |  1 | features.module.3.weight  | (192, 64, 5, 5)  |        307200 |         115560 |    0.00000 |    0.00000 |  0.00000 |  1.91243 |  0.00000 |   62.38281 | 0.04703 | -0.00250 |    0.02289 |
 |  2 | features.module.6.weight  | (384, 192, 3, 3) |        663552 |         256565 |    0.00000 |    0.00000 |  0.00000 |  6.18490 |  0.00000 |   61.33445 | 0.03354 | -0.00184 |    0.01803 |
 |  3 | features.module.8.weight  | (256, 384, 3, 3) |        884736 |         315065 |    0.00000 |    0.00000 |  0.00000 |  6.96411 |  0.00000 |   64.38881 | 0.02646 | -0.00168 |    0.01422 |
 |  4 | features.module.10.weight | (256, 256, 3, 3) |        589824 |         186938 |    0.00000 |    0.00000 |  0.00000 | 15.49225 |  0.00000 |   68.30614 | 0.02714 | -0.00246 |    0.01409 |
 |  5 | classifier.1.weight       | (4096, 9216)     |      37748736 |        3398881 |    0.00000 |    0.21973 |  0.00000 |  0.21973 |  0.00000 |   90.99604 | 0.00589 | -0.00020 |    0.00168 |
 |  6 | classifier.4.weight       | (4096, 4096)     |      16777216 |        1782769 |    0.21973 |    3.46680 |  0.00000 |  3.46680 |  0.00000 |   89.37387 | 0.00849 | -0.00066 |    0.00263 |
 |  7 | classifier.6.weight       | (1000, 4096)     |       4096000 |         994738 |    3.36914 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   75.71440 | 0.01718 |  0.00030 |    0.00778 |
 |  8 | Total sparsity:           | -                |      61090496 |        7063928 |    0.00000 |    0.00000 |  0.00000 |  0.00000 |  0.00000 |   88.43694 | 0.00000 |  0.00000 |    0.00000 |
 +----+---------------------------+------------------+---------------+----------------+------------+------------+----------+----------+----------+------------+---------+----------+------------+
 2018-04-04 21:30:52,499 - Total sparsity: 88.44

 2018-04-04 21:30:52,499 - --- validate (epoch=89)-----------
 2018-04-04 21:30:52,499 - 128116 samples (256 per mini-batch)
 2018-04-04 21:31:04,646 - Epoch: [89][   50/  500]    Loss 2.175988    Top1 51.289063    Top5 74.023438
 2018-04-04 21:31:06,427 - Epoch: [89][  100/  500]    Loss 2.171564    Top1 51.175781    Top5 74.308594
 2018-04-04 21:31:11,432 - Epoch: [89][  150/  500]    Loss 2.159347    Top1 51.546875    Top5 74.473958
 2018-04-04 21:31:14,364 - Epoch: [89][  200/  500]    Loss 2.156857    Top1 51.585938    Top5 74.568359
 2018-04-04 21:31:18,381 - Epoch: [89][  250/  500]    Loss 2.152790    Top1 51.707813    Top5 74.681250
 2018-04-04 21:31:22,195 - Epoch: [89][  300/  500]    Loss 2.149962    Top1 51.791667    Top5 74.755208
 2018-04-04 21:31:25,508 - Epoch: [89][  350/  500]    Loss 2.150936    Top1 51.827009    Top5 74.767857
 2018-04-04 21:31:29,538 - Epoch: [89][  400/  500]    Loss 2.150853    Top1 51.781250    Top5 74.763672
 2018-04-04 21:31:32,842 - Epoch: [89][  450/  500]    Loss 2.150156    Top1 51.828125    Top5 74.821181
 2018-04-04 21:31:35,338 - Epoch: [89][  500/  500]    Loss 2.150417    Top1 51.833594    Top5 74.817187
 2018-04-04 21:31:35,357 - ==> Top1: 51.838    Top5: 74.817    Loss: 2.150

 2018-04-04 21:31:35,364 - Saving checkpoint
 2018-04-04 21:31:39,251 - --- test ---------------------
 2018-04-04 21:31:39,252 - 50000 samples (256 per mini-batch)
 2018-04-04 21:31:51,512 - Test: [   50/  195]    Loss 1.487607    Top1 63.273438    Top5 85.695312
 2018-04-04 21:31:55,015 - Test: [  100/  195]    Loss 1.638043    Top1 60.636719    Top5 83.664062
 2018-04-04 21:31:58,732 - Test: [  150/  195]    Loss 1.833214    Top1 57.619792    Top5 80.447917
 2018-04-04 21:32:01,274 - ==> Top1: 56.606    Top5: 79.446    Loss: 1.893

具体来看:

$ time python3 compress_classifier.py -a alexnet --lr 0.005 -p 50 ../../../data.imagenet -j 44 --epochs 90 --pretrained --compress=../sensitivity-pruning/alexnet.schedule_sensitivity.yaml

在这个例子中,我们对预训练的AlexNet网络进行剪枝,并使用以下参数:

  • 0.005的学习率
  • 每50个mini-batch输出一次信息
  • 使用44核线程加载数据(取决于个人服务器)
  • 训练90个epoch
  • 剪枝策略在alexnet.schedule_sensitivity.yaml中提供
  • 输出到日志到logs

示例

Distiller附带了几个使用compress_classifier.py的示例,配置比较简单,可以将配置文件(YAML)直接作为命令行compress参数输入。具体可以浏览examples文件夹。同时网上附带了几个训练好的模型,不过大部分是剪枝方面的。

剪枝敏感度分析

Distiller支持element-wise和filter-wise的剪枝灵敏度分析。在这两种情况下,L1-norm用于对要修剪的元素或过滤器进行排序。 例如,当运行filter-wise剪枝灵敏度分析时,计算每层的权重张量的滤波器的L1范数,并将底部x%设置为零。
分析过程很长,因为目前使用整个测试数据集来评估每个权重张量的每个剪枝级别的准确性。结果输出为CSV文件(sensitivity.csv)和PNG文件(sensitivity.png)。该实现位于distiller/sensitivity.py中。这里也有一个jupyter的notebook。

训练后量化

Distiller支持训练模块的训练后量化,无需重新训练(使用基于范围的线性量化)。因此,任何模型(无论是否剪枝)都可以量化。要调用训练后量化,请使用--quantize-eval--evaluate。 其他参数可用于控制量化参数:

Arguments controlling quantization at evaluation time("post-training quantization"):
  --quantize-eval, --qe
                        
                        Applicable only if --evaluate is also set
  --qe-mode QE_MODE, --qem QE_MODE
                        Linear quantization mode. Choices: asym_s | asym_u |
                        sym
  --qe-bits-acts NUM_BITS, --qeba NUM_BITS
                        Number of bits for quantization of activations
  --qe-bits-wts NUM_BITS, --qebw NUM_BITS
                        Number of bits for quantization of weights
  --qe-bits-accum NUM_BITS
                        Number of bits for quantization of the accumulator
  --qe-clip-acts, --qeca
                        Enable clipping of activations using min/max values
                        averaging over batch
  --qe-no-clip-layers LAYER_NAME [LAYER_NAME ...], --qencl LAYER_NAME [LAYER_NAME ...]
                        List of fully-qualified layer names for which not to
                        clip activations. Applicable only if --qe-clip-acts is
                        also set
  --qe-per-channel, --qepc
                        Enable per-channel quantization of weights (per output channel)

下面是一个resnet18的量化示例:

$ python3 compress_classifier.py -a resnet18 ../../../data.imagenet  --pretrained --quantize-eval --evaluate

进行量化的模型将被转储到运行目录中。 它将包含量化的模型参数(数据类型仍然是FP32,但值将是整数)。计算出的量化参数(比例和零点)也存储在每个量化层中。

总结

本文简单介绍了Distiller的安装和示例使用,如何将实验的需求结合到这里面还需要更深入地研究一下。

你可能感兴趣的:(python,pytorch,机器学习)