Model Pruning of ResNet with NNI

NNI (Neural Network Intelligence) is Microsoft's open-source AutoML toolkit. It supports many stages of the machine-learning lifecycle, including feature engineering, neural architecture search (NAS), hyperparameter tuning, and model compression (pruning and quantization).

Tutorials on NNI are still fairly scarce online: most posts only introduce the toolkit without working examples, and the pruning code in the official examples even contains a few errors. This post gives a general-purpose example of model pruning with NNI that can automatically prune networks such as ResNet and VGG.

Installing NNI

  • Official GitHub: https://github.com/microsoft/nni
  • Official Chinese documentation (click to view the tutorial)
  • Supports both Windows and Linux
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple nni
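
After installing, a quick way to confirm the package is importable (the printed version is simply whatever happens to be installed, not a requirement):

import nni
print(nni.__version__)   # prints the installed NNI version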

Supported pruning algorithms

| Name | Description |
| --- | --- |
| Level Pruner | Prunes weights proportionally according to their absolute values. |
| AGP Pruner | Automated gradual pruning (the pruning decision is based on how pruning affects the model). https://arxiv.org/abs/1710.01878 |
| Lottery Ticket Pruner | The iterative pruning procedure proposed in "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks", which repeatedly prunes the model. https://arxiv.org/abs/1803.03635 |
| FPGM Pruner | Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration. https://arxiv.org/pdf/1811.00250.pdf |
| L1Filter Pruner | Prunes the filters with the smallest L1 weight norm in convolutional layers (Pruning Filters for Efficient ConvNets). https://arxiv.org/abs/1608.08710 |
| L2Filter Pruner | Prunes the filters with the smallest L2 weight norm in convolutional layers. |
| ActivationAPoZRankFilterPruner | Prunes filters based on the APoZ (Average Percentage of Zeros) metric, which measures the percentage of zeros in a (convolutional) layer's activations. https://arxiv.org/abs/1607.03250 |
| ActivationMeanRankFilterPruner | Prunes filters based on the smallest mean value of the output activations. |
| Slim Pruner | Prunes channels in convolutional layers by pruning the scaling factors of BN layers (Learning Efficient Convolutional Networks through Network Slimming). https://arxiv.org/abs/1708.06519 |
| TaylorFO Pruner | Prunes filters based on a first-order Taylor expansion over the weights (Importance Estimation for Neural Network Pruning). http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf |
| ADMM Pruner | Pruning based on the ADMM optimization technique. https://arxiv.org/abs/1804.03294 |
| NetAdapt Pruner | Iteratively prunes a pretrained network while satisfying a computational-resource budget. https://arxiv.org/abs/1804.03230 |
| SimulatedAnnealing Pruner | Automatic pruning via a heuristic simulated-annealing algorithm. https://arxiv.org/abs/1907.03141 |
| AutoCompress Pruner | Automatic pruning that iteratively invokes the SimulatedAnnealing Pruner and the ADMM Pruner. https://arxiv.org/abs/1907.03141 |
| AMC Pruner | AMC: AutoML for Model Compression and Acceleration on Mobile Devices. https://arxiv.org/pdf/1802.03494.pdf |
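
All of these pruners share a similar interface: the layers to prune and the target sparsity are described in a config_list, and the model is then wrapped by the chosen pruner. A minimal, hypothetical sketch of such a configuration (the layer name below is only illustrative):

# each dict in config_list selects layers (by op_types or op_names) and a sparsity
config_list = [
    {'sparsity': 0.5, 'op_types': ['Conv2d']},          # prune every Conv2d layer to 50%
    {'sparsity': 0.8, 'op_names': ['layer4.1.conv2']},  # a higher ratio for one specific layer
]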

Pruning ResNet-18

Building the model

To adapt to arbitrary input image sizes, the average pooling layer of ResNet is replaced with global average pooling (nn.AdaptiveAvgPool2d).

import torch
import torch.nn as nn
import torch.nn.functional as F


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64
        # this layer is different from torchvision.resnet18() since this model adopted for Cifar10
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.adaptiveAvgPool = nn.AdaptiveAvgPool2d([1, 1])
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.adaptiveAvgPool(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])
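
ResNet18() relies on a BasicBlock class that is not shown above; a minimal sketch following the common CIFAR-10 ResNet implementation (and matching layer names such as layer2.0.shortcut.0 that appear in the tables below) looks roughly like this:

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # 1x1 conv on the shortcut when the spatial size or channel count changes
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)
        return F.relu(out)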

Evaluating model FLOPs and parameters

from nni.compression.pytorch.utils.counter import count_flops_params  # NNI's FLOPs/params counter (path assumes NNI 2.x)

# model is the ResNet18 instance built above, already moved to `device`
input_size = [1, 3, 640, 640]
dummy_input = torch.randn(input_size).to(device)
flops, params, results = count_flops_params(model, dummy_input)
print(f"Model FLOPs {flops/1e6:.2f}M, Params {params/1e6:.2f}M")

The evaluation output:

+-------+---------------------+--------+------------------+-----------+---------+
| Index |         Name        |  Type  |   Weight Shape   |   FLOPs   | #Params |
+-------+---------------------+--------+------------------+-----------+---------+
|   0   |        conv1        | Conv2d |  (64, 3, 3, 3)   |  7077888  |   1728  |
|   1   |    layer1.0.conv1   | Conv2d |  (64, 64, 3, 3)  | 150994944 |  36864  |
|   2   |    layer1.0.conv2   | Conv2d |  (64, 64, 3, 3)  | 150994944 |  36864  |
|   3   |    layer1.1.conv1   | Conv2d |  (64, 64, 3, 3)  | 150994944 |  36864  |
|   4   |    layer1.1.conv2   | Conv2d |  (64, 64, 3, 3)  | 150994944 |  36864  |
|   5   |    layer2.0.conv1   | Conv2d | (128, 64, 3, 3)  |  75497472 |  73728  |
|   6   |    layer2.0.conv2   | Conv2d | (128, 128, 3, 3) | 150994944 |  147456 |
|   7   | layer2.0.shortcut.0 | Conv2d | (128, 64, 1, 1)  |  8388608  |   8192  |
|   8   |    layer2.1.conv1   | Conv2d | (128, 128, 3, 3) | 150994944 |  147456 |
|   9   |    layer2.1.conv2   | Conv2d | (128, 128, 3, 3) | 150994944 |  147456 |
|   10  |    layer3.0.conv1   | Conv2d | (256, 128, 3, 3) |  75497472 |  294912 |
|   11  |    layer3.0.conv2   | Conv2d | (256, 256, 3, 3) | 150994944 |  589824 |
|   12  | layer3.0.shortcut.0 | Conv2d | (256, 128, 1, 1) |  8388608  |  32768  |
|   13  |    layer3.1.conv1   | Conv2d | (256, 256, 3, 3) | 150994944 |  589824 |
|   14  |    layer3.1.conv2   | Conv2d | (256, 256, 3, 3) | 150994944 |  589824 |
|   15  |    layer4.0.conv1   | Conv2d | (512, 256, 3, 3) |  75497472 | 1179648 |
|   16  |    layer4.0.conv2   | Conv2d | (512, 512, 3, 3) | 150994944 | 2359296 |
|   17  | layer4.0.shortcut.0 | Conv2d | (512, 256, 1, 1) |  8388608  |  131072 |
|   18  |    layer4.1.conv1   | Conv2d | (512, 512, 3, 3) | 150994944 | 2359296 |
|   19  |    layer4.1.conv2   | Conv2d | (512, 512, 3, 3) | 150994944 | 2359296 |
|   20  |        linear       | Linear |    (10, 512)     |    5120   |   5130  |
+-------+---------------------+--------+------------------+-----------+---------+
Model FLOPs 27215.47M, Params 11.16M, Time 710.94 ms

Pruning

  • Here the L1Filter pruning algorithm is used, pruning with a sparsity of 0.8:
# Import paths assume NNI 2.x; pruned_model_path and pruned_model_mask_path
# are output file paths defined elsewhere in the script
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
from nni.compression.pytorch import ModelSpeedup

print('\nstart pruning ...')
config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
pruner = L1FilterPruner(model, config_list)
pruner.compress()
# export the masked weights and the corresponding pruning masks
pruner.export_model(model_path=pruned_model_path, mask_path=pruned_model_mask_path)
# remove the pruner wrappers, then physically shrink the network with ModelSpeedup
pruner._unwrap_model()
m_speedup = ModelSpeedup(model, dummy_input, pruned_model_mask_path, device)
m_speedup.speedup_model()
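
After speedup_model() the masked channels are physically removed from the network, so count_flops_params can be run again on the same dummy input to measure the actual compression and produce a summary like the one shown further below:

# re-evaluate the compressed model with the same helper used earlier
flops, params, results = count_flops_params(model, dummy_input)
print(f"Pruned model FLOPs {flops/1e6:.2f}M, Params {params/1e6:.2f}M")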
  • After pruning, the model needs to be fine-tuned to recover accuracy (a sketch of the train/test helpers is given after the snippet below):
print('\nstart finetuning ...')
best_acc = 0
for epoch in range(2):
    train(epoch)
    scheduler.step()
    acc = test()
    if acc > best_acc:
        best_acc = acc
        state_dict = model.state_dict()
save_model(state_dict, speedup_model_path)
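
The train() and test() helpers above are the usual CIFAR-10 training and evaluation loops; a condensed sketch, assuming trainloader, testloader, criterion, and optimizer are defined elsewhere in the script:

def train(epoch):
    model.train()
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

def test():
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    return 100.0 * correct / total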
  • The structure of the pruned model is shown below:
+-------+---------------------+--------+------------------+----------+---------+
| Index |         Name        |  Type  |   Weight Shape   |  FLOPs   | #Params |
+-------+---------------------+--------+------------------+----------+---------+
|   0   |        conv1        | Conv2d |  (31, 3, 3, 3)   | 3428352  |   837   |
|   1   |    layer1.0.conv1   | Conv2d |  (13, 31, 3, 3)  | 14856192 |   3627  |
|   2   |    layer1.0.conv2   | Conv2d |  (31, 13, 3, 3)  | 14856192 |   3627  |
|   3   |    layer1.1.conv1   | Conv2d |  (13, 31, 3, 3)  | 14856192 |   3627  |
|   4   |    layer1.1.conv2   | Conv2d |  (31, 13, 3, 3)  | 14856192 |   3627  |
|   5   |    layer2.0.conv1   | Conv2d |  (26, 31, 3, 3)  | 7428096  |   7254  |
|   6   |    layer2.0.conv2   | Conv2d |  (63, 26, 3, 3)  | 15095808 |  14742  |
|   7   | layer2.0.shortcut.0 | Conv2d |  (63, 31, 1, 1)  | 1999872  |   1953  |
|   8   |    layer2.1.conv1   | Conv2d |  (26, 63, 3, 3)  | 15095808 |  14742  |
|   9   |    layer2.1.conv2   | Conv2d |  (63, 26, 3, 3)  | 15095808 |  14742  |
|   10  |    layer3.0.conv1   | Conv2d |  (52, 63, 3, 3)  | 7547904  |  29484  |
|   11  |    layer3.0.conv2   | Conv2d | (132, 52, 3, 3)  | 15814656 |  61776  |
|   12  | layer3.0.shortcut.0 | Conv2d | (132, 63, 1, 1)  | 2128896  |   8316  |
|   13  |    layer3.1.conv1   | Conv2d | (52, 132, 3, 3)  | 15814656 |  61776  |
|   14  |    layer3.1.conv2   | Conv2d | (132, 52, 3, 3)  | 15814656 |  61776  |
|   15  |    layer4.0.conv1   | Conv2d | (103, 132, 3, 3) | 7831296  |  122364 |
|   16  |    layer4.0.conv2   | Conv2d | (254, 103, 3, 3) | 15069312 |  235458 |
|   17  | layer4.0.shortcut.0 | Conv2d | (254, 132, 1, 1) | 2145792  |  33528  |
|   18  |    layer4.1.conv1   | Conv2d | (103, 254, 3, 3) | 15069312 |  235458 |
|   19  |    layer4.1.conv2   | Conv2d | (254, 103, 3, 3) | 15069312 |  235458 |
|   20  |        linear       | Linear |    (10, 254)     |   2540   |   2550  |
+-------+---------------------+--------+------------------+----------+---------+
Model FLOPs 2571.38M, Params 1.11M, Time 201.12 ms

As the tables show, after pruning at a sparsity of 0.8 the model's parameter count shrinks from 11.16M to 1.11M, and the average CPU inference time drops from over 700 ms to just over 200 ms, a substantial improvement in runtime performance. Click through to view the complete example code.
