Contents
1. Introduction
2. Convolution filter pruning
2.1 PaddleSlim API
2.2 Example
2.2.1 Comparison before and after pruning
PaddleSlim is a toolkit focused on deep learning model compression. It provides pruning, quantization, distillation, and model structure search strategies to help users quickly shrink their models.
This tutorial performs a one-shot channel prune on a convolutional network. Pruning a convolution layer's channels means pruning that layer's output channels. A convolution layer's weight has shape [output_channel, input_channel, kernel_size, kernel_size]; pruning along the first dimension of this weight reduces the number of output channels.
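As a concrete illustration of "pruning the first dimension", the following NumPy sketch (a toy tensor, not PaddleSlim itself) ranks filters by their l1-norm and keeps only the top half — which is all a one-shot filter prune amounts to:

```python
import numpy as np

# Toy conv weight: 4 output channels, 3 input channels, 3x3 kernels.
np.random.seed(0)
weight = np.random.randn(4, 3, 3, 3)

# Importance score per filter: l1-norm over each output channel's weights.
l1 = np.abs(weight).reshape(weight.shape[0], -1).sum(axis=1)
order = np.argsort(l1)        # least important filters first

# Prune 50%: drop the two filters with the smallest l1-norm.
keep = np.sort(order[2:])
pruned = weight[keep]         # slice along axis 0 (output channels)
print(pruned.shape)           # (2, 3, 3, 3)
```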
In practice, pruning should account for each layer's sensitivity: prune a layer at a given ratio, measure accuracy on a validation set to estimate how sensitive the layer is, and then prune the less sensitive layers more aggressively to compress the model.
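That scan can be sketched as below. This is a hypothetical illustration only: `evaluate` is a stand-in for a real validation run, and the layer names, ratios, and tolerance are made up for the sketch.

```python
def evaluate(pruned_layers):
    # Stand-in for a validation run: pretend accuracy falls off
    # linearly with the total pruned ratio.
    return 0.94 - 0.2 * sum(pruned_layers.values())

baseline = evaluate({})
tolerance = 0.05   # maximum acceptable accuracy drop per layer
chosen = {}
for layer in ["conv2d_20.w_0", "conv2d_22.w_0"]:
    for ratio in [0.1, 0.3, 0.5]:
        acc = evaluate({layer: ratio})
        if baseline - acc <= tolerance:
            chosen[layer] = ratio   # largest passing ratio wins (ascending scan)
print(chosen)
```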
PaddleSlim's dynamic-graph pruning API centers on three pruners:

- L1NormFilterPruner: ranks the filters within each convolution layer by the l1-norm of their weights and prunes the relatively unimportant filters at the specified ratio. Pruning filters is equivalent to pruning the layer's output channels.
- L2NormFilterPruner: same as above, but ranks filters by the l2-norm of their weights.
- FPGMFilterPruner: ranks filters by their distance from the geometric median of the layer's filters and prunes the relatively unimportant ones at the specified ratio.

Import the required packages:
import paddle
import paddle.vision.models as models
from paddle.static import InputSpec as Input
from paddle.vision.datasets import Cifar10
import paddle.vision.transforms as T
from paddleslim.dygraph import L1NormFilterPruner
Define the network and load the dataset
net = models.mobilenet_v1()
inputs = Input(shape=[None, 3, 32, 32], dtype='float32', name='image')
labels = Input(shape=[None, 1], dtype='int64', name='label')
optimizer = paddle.optimizer.Momentum(learning_rate=0.1, parameters=net.parameters())
model = paddle.Model(net, inputs, labels)
model.prepare(
    optimizer=optimizer,
    loss=paddle.nn.CrossEntropyLoss(),
    metrics=paddle.metric.Accuracy(topk=(1, 5))
)
transforms = T.Compose([
T.Transpose(),
T.Normalize([127.5], [127.5])
])
train_dataset = Cifar10(mode='train', transform=transforms)
test_dataset = Cifar10(mode='test', transform=transforms)
Train the model
model.fit(train_dataset, epochs=2, batch_size=128, verbose=1)
Compute the model's statistics before pruning, using the paddle.flops function:
flops = paddle.flops(net, input_size=[1, 3, 32, 32], print_detail=True)
The call returns the model's total Params and FLOPs and prints a per-layer breakdown. The names in the Layer Name column are also the basis of the parameter names (e.g. conv2d_22.w_0) that you pass to the pruner to specify what to prune.
+-----------------------+-----------------+-----------------+---------+---------+
| Layer Name | Input Shape | Output Shape | Params | Flops |
+-----------------------+-----------------+-----------------+---------+---------+
| conv2d_0 | [1, 3, 32, 32] | [1, 32, 16, 16] | 864 | 221184 |
| batch_norm2d_0 | [1, 32, 16, 16] | [1, 32, 16, 16] | 128 | 16384 |
| re_lu_0 | [1, 32, 16, 16] | [1, 32, 16, 16] | 0 | 0 |
| conv2d_1 | [1, 32, 16, 16] | [1, 32, 16, 16] | 288 | 73728 |
| batch_norm2d_1 | [1, 32, 16, 16] | [1, 32, 16, 16] | 128 | 16384 |
| re_lu_1 | [1, 32, 16, 16] | [1, 32, 16, 16] | 0 | 0 |
| conv2d_2 | [1, 32, 16, 16] | [1, 64, 16, 16] | 2048 | 524288 |
| batch_norm2d_2 | [1, 64, 16, 16] | [1, 64, 16, 16] | 256 | 32768 |
| re_lu_2 | [1, 64, 16, 16] | [1, 64, 16, 16] | 0 | 0 |
| conv2d_3 | [1, 64, 16, 16] | [1, 64, 8, 8] | 576 | 36864 |
| batch_norm2d_3 | [1, 64, 8, 8] | [1, 64, 8, 8] | 256 | 8192 |
| re_lu_3 | [1, 64, 8, 8] | [1, 64, 8, 8] | 0 | 0 |
| conv2d_4 | [1, 64, 8, 8] | [1, 128, 8, 8] | 8192 | 524288 |
| batch_norm2d_4 | [1, 128, 8, 8] | [1, 128, 8, 8] | 512 | 16384 |
| re_lu_4 | [1, 128, 8, 8] | [1, 128, 8, 8] | 0 | 0 |
| conv2d_5 | [1, 128, 8, 8] | [1, 128, 8, 8] | 1152 | 73728 |
| batch_norm2d_5 | [1, 128, 8, 8] | [1, 128, 8, 8] | 512 | 16384 |
| re_lu_5 | [1, 128, 8, 8] | [1, 128, 8, 8] | 0 | 0 |
| conv2d_6 | [1, 128, 8, 8] | [1, 128, 8, 8] | 16384 | 1048576 |
| batch_norm2d_6 | [1, 128, 8, 8] | [1, 128, 8, 8] | 512 | 16384 |
| re_lu_6 | [1, 128, 8, 8] | [1, 128, 8, 8] | 0 | 0 |
| conv2d_7 | [1, 128, 8, 8] | [1, 128, 4, 4] | 1152 | 18432 |
| batch_norm2d_7 | [1, 128, 4, 4] | [1, 128, 4, 4] | 512 | 4096 |
| re_lu_7 | [1, 128, 4, 4] | [1, 128, 4, 4] | 0 | 0 |
| conv2d_8 | [1, 128, 4, 4] | [1, 256, 4, 4] | 32768 | 524288 |
| batch_norm2d_8 | [1, 256, 4, 4] | [1, 256, 4, 4] | 1024 | 8192 |
| re_lu_8 | [1, 256, 4, 4] | [1, 256, 4, 4] | 0 | 0 |
| conv2d_9 | [1, 256, 4, 4] | [1, 256, 4, 4] | 2304 | 36864 |
| batch_norm2d_9 | [1, 256, 4, 4] | [1, 256, 4, 4] | 1024 | 8192 |
| re_lu_9 | [1, 256, 4, 4] | [1, 256, 4, 4] | 0 | 0 |
| conv2d_10 | [1, 256, 4, 4] | [1, 256, 4, 4] | 65536 | 1048576 |
| batch_norm2d_10 | [1, 256, 4, 4] | [1, 256, 4, 4] | 1024 | 8192 |
| re_lu_10 | [1, 256, 4, 4] | [1, 256, 4, 4] | 0 | 0 |
| conv2d_11 | [1, 256, 4, 4] | [1, 256, 2, 2] | 2304 | 9216 |
| batch_norm2d_11 | [1, 256, 2, 2] | [1, 256, 2, 2] | 1024 | 2048 |
| re_lu_11 | [1, 256, 2, 2] | [1, 256, 2, 2] | 0 | 0 |
| conv2d_12 | [1, 256, 2, 2] | [1, 512, 2, 2] | 131072 | 524288 |
| batch_norm2d_12 | [1, 512, 2, 2] | [1, 512, 2, 2] | 2048 | 4096 |
| re_lu_12 | [1, 512, 2, 2] | [1, 512, 2, 2] | 0 | 0 |
| conv2d_13 | [1, 512, 2, 2] | [1, 512, 2, 2] | 4608 | 18432 |
| batch_norm2d_13 | [1, 512, 2, 2] | [1, 512, 2, 2] | 2048 | 4096 |
| re_lu_13 | [1, 512, 2, 2] | [1, 512, 2, 2] | 0 | 0 |
| conv2d_14 | [1, 512, 2, 2] | [1, 512, 2, 2] | 262144 | 1048576 |
| batch_norm2d_14 | [1, 512, 2, 2] | [1, 512, 2, 2] | 2048 | 4096 |
| re_lu_14 | [1, 512, 2, 2] | [1, 512, 2, 2] | 0 | 0 |
| conv2d_15 | [1, 512, 2, 2] | [1, 512, 2, 2] | 4608 | 18432 |
| batch_norm2d_15 | [1, 512, 2, 2] | [1, 512, 2, 2] | 2048 | 4096 |
| re_lu_15 | [1, 512, 2, 2] | [1, 512, 2, 2] | 0 | 0 |
| conv2d_16 | [1, 512, 2, 2] | [1, 512, 2, 2] | 262144 | 1048576 |
| batch_norm2d_16 | [1, 512, 2, 2] | [1, 512, 2, 2] | 2048 | 4096 |
| re_lu_16 | [1, 512, 2, 2] | [1, 512, 2, 2] | 0 | 0 |
| conv2d_17 | [1, 512, 2, 2] | [1, 512, 2, 2] | 4608 | 18432 |
| batch_norm2d_17 | [1, 512, 2, 2] | [1, 512, 2, 2] | 2048 | 4096 |
| re_lu_17 | [1, 512, 2, 2] | [1, 512, 2, 2] | 0 | 0 |
| conv2d_18 | [1, 512, 2, 2] | [1, 512, 2, 2] | 262144 | 1048576 |
| batch_norm2d_18 | [1, 512, 2, 2] | [1, 512, 2, 2] | 2048 | 4096 |
| re_lu_18 | [1, 512, 2, 2] | [1, 512, 2, 2] | 0 | 0 |
| conv2d_19 | [1, 512, 2, 2] | [1, 512, 2, 2] | 4608 | 18432 |
| batch_norm2d_19 | [1, 512, 2, 2] | [1, 512, 2, 2] | 2048 | 4096 |
| re_lu_19 | [1, 512, 2, 2] | [1, 512, 2, 2] | 0 | 0 |
| conv2d_20 | [1, 512, 2, 2] | [1, 512, 2, 2] | 262144 | 1048576 |
| batch_norm2d_20 | [1, 512, 2, 2] | [1, 512, 2, 2] | 2048 | 4096 |
| re_lu_20 | [1, 512, 2, 2] | [1, 512, 2, 2] | 0 | 0 |
| conv2d_21 | [1, 512, 2, 2] | [1, 512, 2, 2] | 4608 | 18432 |
| batch_norm2d_21 | [1, 512, 2, 2] | [1, 512, 2, 2] | 2048 | 4096 |
| re_lu_21 | [1, 512, 2, 2] | [1, 512, 2, 2] | 0 | 0 |
| conv2d_22 | [1, 512, 2, 2] | [1, 512, 2, 2] | 262144 | 1048576 |
| batch_norm2d_22 | [1, 512, 2, 2] | [1, 512, 2, 2] | 2048 | 4096 |
| re_lu_22 | [1, 512, 2, 2] | [1, 512, 2, 2] | 0 | 0 |
| conv2d_23 | [1, 512, 2, 2] | [1, 512, 1, 1] | 4608 | 4608 |
| batch_norm2d_23 | [1, 512, 1, 1] | [1, 512, 1, 1] | 2048 | 1024 |
| re_lu_23 | [1, 512, 1, 1] | [1, 512, 1, 1] | 0 | 0 |
| conv2d_24 | [1, 512, 1, 1] | [1, 1024, 1, 1] | 524288 | 524288 |
| batch_norm2d_24 | [1, 1024, 1, 1] | [1, 1024, 1, 1] | 4096 | 2048 |
| re_lu_24 | [1, 1024, 1, 1] | [1, 1024, 1, 1] | 0 | 0 |
| conv2d_25 | [1, 1024, 1, 1] | [1, 1024, 1, 1] | 9216 | 9216 |
| batch_norm2d_25 | [1, 1024, 1, 1] | [1, 1024, 1, 1] | 4096 | 2048 |
| re_lu_25 | [1, 1024, 1, 1] | [1, 1024, 1, 1] | 0 | 0 |
| conv2d_26 | [1, 1024, 1, 1] | [1, 1024, 1, 1] | 1048576 | 1048576 |
| batch_norm2d_26 | [1, 1024, 1, 1] | [1, 1024, 1, 1] | 4096 | 2048 |
| re_lu_26 | [1, 1024, 1, 1] | [1, 1024, 1, 1] | 0 | 0 |
| adaptive_avg_pool2d_0 | [1, 1024, 1, 1] | [1, 1024, 1, 1] | 0 | 2048 |
| linear_0 | [1, 1024] | [1, 1000] | 1025000 | 1024000 |
+-----------------------+-----------------+-----------------+---------+---------+
Total Flops: 12817920 Total Params: 4253864
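The table above can be sanity-checked by hand: for a bias-free convolution, Params = C_out × C_in × k × k, and the FLOPs reported here equal Params × H_out × W_out. A quick check of two rows, with the shapes taken from the table:

```python
# conv2d_0: 3 -> 32 channels, 3x3 kernel, 16x16 output, no bias.
params_0 = 32 * 3 * 3 * 3
flops_0 = params_0 * 16 * 16
print(params_0, flops_0)      # 864 221184

# conv2d_24: 512 -> 1024 channels, 1x1 kernel, 1x1 output.
params_24 = 1024 * 512 * 1 * 1
flops_24 = params_24 * 1 * 1
print(params_24, flops_24)    # 524288 524288
```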
Evaluate accuracy
model.evaluate(test_dataset, batch_size=128, verbose=1)
Accuracy before pruning: 0.937
{'loss': [1.3092904], 'acc': 0.93706}
Pruning. Prune two different layers of the network by parameter name: conv2d_22 at a ratio of 50% and conv2d_20 at 60%.
pruner = L1NormFilterPruner(net, [1, 3, 32, 32])
pruner.prune_vars({'conv2d_22.w_0':0.5, 'conv2d_20.w_0':0.6}, axis=0)
Compute the FLOPs after pruning:
flops = paddle.flops(net, input_size=[1, 3, 32, 32], print_detail=True)
Partial results are shown below. conv2d_20 was pruned at ratio 0.6 and conv2d_22 at ratio 0.5 (note that pruning conv2d_20's output channels also shrinks conv2d_22's input channels):

| layer | params / flops before | params / flops after |
|---|---|---|
| conv2d_20 | 262144 / 1048576 | 104960 / 419840 |
| conv2d_22 | 262144 / 1048576 | 52480 / 209920 |
| total | 4253864 / 12817920 | 3615301 / 11067556 |
+-----------------------+-----------------+-----------------+---------+---------+
| Layer Name | Input Shape | Output Shape | Params | Flops |
+-----------------------+-----------------+-----------------+---------+---------+
| conv2d_20 | [1, 512, 2, 2] | [1, 205, 2, 2] | 104960 | 419840 |
| conv2d_22 | [1, 205, 2, 2] | [1, 256, 2, 2] | 52480 | 209920 |
+-----------------------+-----------------+-----------------+---------+---------+
Total Flops: 11067556 Total Params: 3615301
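The pruned channel counts follow directly from the ratios passed to prune_vars, assuming the number of pruned filters is floored (an assumption about the rounding convention, but one that matches the table above):

```python
# conv2d_20: 512 output channels pruned at ratio 0.6.
out_20 = 512 - int(512 * 0.6)   # 205 channels kept
params_20 = out_20 * 512        # 1x1 conv: C_out * C_in

# conv2d_22: 512 output channels pruned at ratio 0.5,
# and its input was already pruned to 205 channels.
out_22 = 512 - int(512 * 0.5)   # 256 channels kept
params_22 = out_22 * out_20

print(out_20, params_20)        # 205 104960
print(out_22, params_22)        # 256 52480
```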
Accuracy after pruning
model.evaluate(test_dataset, batch_size=128, verbose=1)
Accuracy drops from 0.93 to 0.76: pruning a model causes some loss of accuracy.
{'loss': [2.2277398], 'acc_top5': 0.76516}
Fine-tuning helps the model recover its original accuracy. The following code fine-tunes the pruned model for one epoch and then re-evaluates it:
optimizer = paddle.optimizer.Momentum(
learning_rate=0.1,
parameters=net.parameters())
model.prepare(
optimizer,
paddle.nn.CrossEntropyLoss(),
paddle.metric.Accuracy(topk=(1, 5)))
model.fit(train_dataset, epochs=1, batch_size=128, verbose=1)
Evaluate
model.evaluate(test_dataset, batch_size=128, verbose=1)
After fine-tuning, accuracy recovers to 0.94 — even higher than before pruning, because the model had not been trained to its optimum before pruning.
{'loss': [1.2696353], 'acc_top1': 0.54044, 'acc_top5': 0.94396}
| | before pruning | after pruning |
|---|---|---|
| params | 4253864 | 3615301 |
| flops | 12817920 | 11067556 |
| accuracy | 0.937 | 0.94 |
Evaluation time per step: 1 s/step before pruning vs. 905 ms/step after pruning.
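Putting the numbers together: the roughly 13.7% FLOPs reduction translates into a roughly 9.5% reduction in evaluation time per step — wall time does not scale one-to-one with FLOPs, since memory access and framework overhead also contribute. A quick check:

```python
flops_before, flops_after = 12817920, 11067556
params_before, params_after = 4253864, 3615301

print(round(1 - flops_after / flops_before, 3))    # 0.137 -> ~13.7% fewer FLOPs
print(round(1 - params_after / params_before, 3))  # 0.15  -> ~15% fewer params
print(round(1 - 0.905 / 1.0, 3))                   # 0.095 -> ~9.5% faster steps
```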