paddleSlim(一)剪裁

目录

1.简介

2.卷积Filter剪裁

2.1paddleSlim的API

 2.2示例

 2.2.1剪裁前后对比


1.简介

PaddleSlim是一个专注于深度学习模型压缩的工具库,提供剪裁、量化、蒸馏、和模型结构搜索等模型压缩策略,帮助用户快速实现模型的小型化。

2.卷积Filter剪裁

对卷积网络的通道进行一次剪裁。剪裁一个卷积层的通道,是指剪裁该卷积层输出的通道。卷积层的权重形状为 [output_channel, input_channel, kernel_size, kernel_size] ,通过剪裁该权重的第一纬度达到剪裁输出通道数的目的。

实际剪裁时要考虑到每层通道的敏感度,一般剪裁后要在验证集上测试精度得到敏感度,敏感度低的剪裁掉来压缩模型。

2.1paddleSlim的API

在paddleslim中关于动态图的剪裁接口主要有三个:

  • L1NormFilterPruner        该剪裁器按 Filters 的 l1-norm 统计值对单个卷积层内的 Filters 的重要性进行排序,并按指定比例剪裁掉相对不重要的 Filters 。对 Filters 的剪裁等价于剪裁卷积层的输出通道数。
  • L2NormFilterPruner        该剪裁器按 Filters 的 l2-norm 统计值对单个卷积层内的 Filters 的重要性进行排序,并按指定比例剪裁掉相对不重要的 Filters 。对 Filters 的剪裁等价于剪裁卷积层的输出通道数。
  • FPGMFilterPruner          该剪裁器按论文 Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration _ 中的统计方法对单个卷积层内的 Filters 的重要性进行排序,并按指定比例剪裁掉相对不重要的 Filters 。对 Filters 的剪裁等价于剪裁卷积层的输出通道数。

 2.2示例

导入各种包

import paddle
import paddle.vision.models as models
from paddle.static import InputSpec as Input
from paddle.vision.datasets import Cifar10
import paddle.vision.transforms as T
from paddleslim.dygraph import L1NormFilterPruner

网络定义和数据集加载 

net = models.mobilenet_v1()
inputs = Input(shape=[None, 3, 32, 32], dtype='float32', name='image')
labels = Input(shape=[None, 1], dtype='int64', name='label')
optmizer = paddle.optimizer.Momentum(learning_rate=0.1, parameters=net.parameters())
model = paddle.Model(net, inputs, labels)
model.prepare(
    optimizer=optmizer, 
    loss=paddle.nn.CrossEntropyLoss(),
    metrics=paddle.metric.Accuracy(topk=(1, 5))
)

transforms = T.Compose([
    T.Transpose(),
    T.Normalize([127.5], [127.5])
])

train_dataset = Cifar10(mode='train', transform=transforms)
test_dataset = Cifar10(mode='train', transform=transforms)

训练模型

model.fit(train_dataset, epochs=2, batch_size=128, verbose=1)

 计算剪裁之前的模型相关信息。使用paddle.flops函数

flops = paddle.flops(net, input_size=[1, 3, 32, 32], print_detail=True)

结果会返回模型的详细参数量Params和Flops,此外在Layer Name中的名字就是调用剪裁器pruner要给出剪裁参数名。

+-----------------------+-----------------+-----------------+---------+---------+
|       Layer Name      |   Input Shape   |   Output Shape  |  Params |  Flops  |
+-----------------------+-----------------+-----------------+---------+---------+
|        conv2d_0       |  [1, 3, 32, 32] | [1, 32, 16, 16] |   864   |  221184 |
|     batch_norm2d_0    | [1, 32, 16, 16] | [1, 32, 16, 16] |   128   |  16384  |
|        re_lu_0        | [1, 32, 16, 16] | [1, 32, 16, 16] |    0    |    0    |
|        conv2d_1       | [1, 32, 16, 16] | [1, 32, 16, 16] |   288   |  73728  |
|     batch_norm2d_1    | [1, 32, 16, 16] | [1, 32, 16, 16] |   128   |  16384  |
|        re_lu_1        | [1, 32, 16, 16] | [1, 32, 16, 16] |    0    |    0    |
|        conv2d_2       | [1, 32, 16, 16] | [1, 64, 16, 16] |   2048  |  524288 |
|     batch_norm2d_2    | [1, 64, 16, 16] | [1, 64, 16, 16] |   256   |  32768  |
|        re_lu_2        | [1, 64, 16, 16] | [1, 64, 16, 16] |    0    |    0    |
|        conv2d_3       | [1, 64, 16, 16] |  [1, 64, 8, 8]  |   576   |  36864  |
|     batch_norm2d_3    |  [1, 64, 8, 8]  |  [1, 64, 8, 8]  |   256   |   8192  |
|        re_lu_3        |  [1, 64, 8, 8]  |  [1, 64, 8, 8]  |    0    |    0    |
|        conv2d_4       |  [1, 64, 8, 8]  |  [1, 128, 8, 8] |   8192  |  524288 |
|     batch_norm2d_4    |  [1, 128, 8, 8] |  [1, 128, 8, 8] |   512   |  16384  |
|        re_lu_4        |  [1, 128, 8, 8] |  [1, 128, 8, 8] |    0    |    0    |
|        conv2d_5       |  [1, 128, 8, 8] |  [1, 128, 8, 8] |   1152  |  73728  |
|     batch_norm2d_5    |  [1, 128, 8, 8] |  [1, 128, 8, 8] |   512   |  16384  |
|        re_lu_5        |  [1, 128, 8, 8] |  [1, 128, 8, 8] |    0    |    0    |
|        conv2d_6       |  [1, 128, 8, 8] |  [1, 128, 8, 8] |  16384  | 1048576 |
|     batch_norm2d_6    |  [1, 128, 8, 8] |  [1, 128, 8, 8] |   512   |  16384  |
|        re_lu_6        |  [1, 128, 8, 8] |  [1, 128, 8, 8] |    0    |    0    |
|        conv2d_7       |  [1, 128, 8, 8] |  [1, 128, 4, 4] |   1152  |  18432  |
|     batch_norm2d_7    |  [1, 128, 4, 4] |  [1, 128, 4, 4] |   512   |   4096  |
|        re_lu_7        |  [1, 128, 4, 4] |  [1, 128, 4, 4] |    0    |    0    |
|        conv2d_8       |  [1, 128, 4, 4] |  [1, 256, 4, 4] |  32768  |  524288 |
|     batch_norm2d_8    |  [1, 256, 4, 4] |  [1, 256, 4, 4] |   1024  |   8192  |
|        re_lu_8        |  [1, 256, 4, 4] |  [1, 256, 4, 4] |    0    |    0    |
|        conv2d_9       |  [1, 256, 4, 4] |  [1, 256, 4, 4] |   2304  |  36864  |
|     batch_norm2d_9    |  [1, 256, 4, 4] |  [1, 256, 4, 4] |   1024  |   8192  |
|        re_lu_9        |  [1, 256, 4, 4] |  [1, 256, 4, 4] |    0    |    0    |
|       conv2d_10       |  [1, 256, 4, 4] |  [1, 256, 4, 4] |  65536  | 1048576 |
|    batch_norm2d_10    |  [1, 256, 4, 4] |  [1, 256, 4, 4] |   1024  |   8192  |
|        re_lu_10       |  [1, 256, 4, 4] |  [1, 256, 4, 4] |    0    |    0    |
|       conv2d_11       |  [1, 256, 4, 4] |  [1, 256, 2, 2] |   2304  |   9216  |
|    batch_norm2d_11    |  [1, 256, 2, 2] |  [1, 256, 2, 2] |   1024  |   2048  |
|        re_lu_11       |  [1, 256, 2, 2] |  [1, 256, 2, 2] |    0    |    0    |
|       conv2d_12       |  [1, 256, 2, 2] |  [1, 512, 2, 2] |  131072 |  524288 |
|    batch_norm2d_12    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_12       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_13       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   4608  |  18432  |
|    batch_norm2d_13    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_13       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_14       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |  262144 | 1048576 |
|    batch_norm2d_14    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_14       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_15       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   4608  |  18432  |
|    batch_norm2d_15    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_15       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_16       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |  262144 | 1048576 |
|    batch_norm2d_16    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_16       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_17       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   4608  |  18432  |
|    batch_norm2d_17    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_17       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_18       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |  262144 | 1048576 |
|    batch_norm2d_18    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_18       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_19       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   4608  |  18432  |
|    batch_norm2d_19    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_19       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_20       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |  262144 | 1048576 |
|    batch_norm2d_20    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_20       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_21       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   4608  |  18432  |
|    batch_norm2d_21    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_21       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_22       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |  262144 | 1048576 |
|    batch_norm2d_22    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_22       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_23       |  [1, 512, 2, 2] |  [1, 512, 1, 1] |   4608  |   4608  |
|    batch_norm2d_23    |  [1, 512, 1, 1] |  [1, 512, 1, 1] |   2048  |   1024  |
|        re_lu_23       |  [1, 512, 1, 1] |  [1, 512, 1, 1] |    0    |    0    |
|       conv2d_24       |  [1, 512, 1, 1] | [1, 1024, 1, 1] |  524288 |  524288 |
|    batch_norm2d_24    | [1, 1024, 1, 1] | [1, 1024, 1, 1] |   4096  |   2048  |
|        re_lu_24       | [1, 1024, 1, 1] | [1, 1024, 1, 1] |    0    |    0    |
|       conv2d_25       | [1, 1024, 1, 1] | [1, 1024, 1, 1] |   9216  |   9216  |
|    batch_norm2d_25    | [1, 1024, 1, 1] | [1, 1024, 1, 1] |   4096  |   2048  |
|        re_lu_25       | [1, 1024, 1, 1] | [1, 1024, 1, 1] |    0    |    0    |
|       conv2d_26       | [1, 1024, 1, 1] | [1, 1024, 1, 1] | 1048576 | 1048576 |
|    batch_norm2d_26    | [1, 1024, 1, 1] | [1, 1024, 1, 1] |   4096  |   2048  |
|        re_lu_26       | [1, 1024, 1, 1] | [1, 1024, 1, 1] |    0    |    0    |
| adaptive_avg_pool2d_0 | [1, 1024, 1, 1] | [1, 1024, 1, 1] |    0    |   2048  |
|        linear_0       |    [1, 1024]    |    [1, 1000]    | 1025000 | 1024000 |
+-----------------------+-----------------+-----------------+---------+---------+
Total Flops: 12817920     Total Params: 4253864

评估精度 

model.evaluate(test_dataset, batch_size=128, verbose=1)

剪裁前的精度为:0.937 

{'loss': [1.3092904],  'acc': 0.93706}

剪裁 。对网络模型两个不同的网络层按照参数名分别进行比例为50%,60%的裁剪。

pruner = L1NormFilterPruner(net, [1, 3, 32, 32])
pruner.prune_vars({'conv2d_22.w_0':0.5, 'conv2d_20.w_0':0.6}, axis=0)

 计算剪裁之后的flops

flops = paddle.flops(net, input_size=[1, 3, 32, 32], print_detail=True)

这里给出部分结果,剪裁的conv2d_20和conv2d_22比例为0.5,0.6

                        params/flops

conv2d_20        262144/1048576 ---->104960/419840

conv2d_22        262144/1048576 -----> 52480/209920

total                4253864/12817920 ------> 3615301/11067556 

+-----------------------+-----------------+-----------------+---------+---------+
|       Layer Name      |   Input Shape   |   Output Shape  |  Params |  Flops  |
+-----------------------+-----------------+-----------------+---------+---------+

|       conv2d_20       |  [1, 512, 2, 2] |  [1, 205, 2, 2] |  104960 |  419840 |

|       conv2d_22       |  [1, 205, 2, 2] |  [1, 256, 2, 2] |  52480  |  209920 |

+-----------------------+-----------------+-----------------+---------+---------+
Total Flops: 11067556     Total Params: 3615301

剪裁后的精度

model.evaluate(test_dataset, batch_size=128, verbose=1)

精度由0.93降为0.76 。对模型进行裁剪会导致模型精度有一定程度下降。

{'loss': [2.2277398], 'acc_top5': 0.76516}

 对模型进行微调会有助于模型恢复原有精度。 以下代码对裁剪过后的模型进行评估后执行了一个epoch的微调,再对微调过后的模型重新进行评估: 

optimizer = paddle.optimizer.Momentum(
    learning_rate=0.1,
    parameters=net.parameters())

model.prepare(
    optimizer,
    paddle.nn.CrossEntropyLoss(),
    paddle.metric.Accuracy(topk=(1, 5)))

model.fit(train_dataset, epochs=1, batch_size=128, verbose=1)

评估

model.evaluate(test_dataset, batch_size=128, verbose=1)

 微调后精度恢复到0.94,比剪裁prune前还要高是因为剪裁前模型没有调到最优。

{'loss': [1.2696353], 'acc_top1': 0.54044, 'acc_top5': 0.94396}

 2.2.1剪裁前后对比

剪裁前 剪裁后
params 4253864  3615301
flops 12817920 11067556
accuracy 0.937 0.94
剪裁前的评估每个样本耗时:1s/step 
剪裁后的评估每个样本耗时:905ms/step

你可能感兴趣的:(paddlepaddle,人工智能)