Paper reproduction: Learning both Weights and Connections for Efficient Neural Networks

Core idea of the paper

The paper proposes an unstructured pruning strategy that targets convolutional-layer weights, together with its now-famous three-step pruning pipeline (train, prune, retrain). Weight importance is judged by magnitude: the network is trained with L1 or L2 regularization, and then, according to the chosen prune rate, the weights with the smallest magnitudes are set to zero.
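As a rough sketch of this pipeline (my own summary, not the author's code; train_fn and prune_fn stand in for the training loop and the magnitude pruner discussed later):

def train_prune_retrain(model, train_fn, prune_fn, prune_rates=(50.0, 75.0, 85.0)):
    # 1. train the dense network so that weight magnitudes reflect importance
    train_fn(model)
    for rate in prune_rates:
        # 2. prune: zero out connections whose magnitude falls below the
        #    threshold implied by the current prune rate
        prune_fn(model, rate)
        # 3. retrain the surviving weights to recover the lost accuracy
        train_fn(model)
    return model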

Notes on the paper's details

Why compress models: the paper motivates compression from the angle of energy consumption. The larger the network, the more parameters and computation it needs, so a forward pass costs more energy, which is punishing for mobile and embedded devices. Large models also cannot fit in SRAM (for process reasons SRAM is typically only a few MB to a few tens of MB) and have to live in DRAM, and reading model data from DRAM costs more than 100x the energy of reading it from SRAM.
Regularization: the paper considers both L1 and L2 regularization for judging weight importance; experimentally, L2 regularization gives better results.
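
For intuition, this is roughly how the two penalties would be added to the training loss (a hedged sketch of mine, not code from the repo; the helper name and the filter on parameter names are my own choices):

def regularization_penalty(model, reg_type="l2", strength=1e-4):
    # L1 sums |w| and pushes many weights toward exactly zero;
    # L2 sums w^2 and shrinks weights more gently (the paper reports
    # better accuracy with L2 when pruning is followed by retraining)
    penalty = 0.0
    for name, p in model.named_parameters():
        if "conv" in name and name.endswith("weight"):
            penalty = penalty + (p.abs().sum() if reg_type == "l1" else p.pow(2).sum())
    return strength * penalty

# inside a training step:
# loss = criterion(output, target) + regularization_penalty(model, "l2")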

Neuron pruning: after weight pruning the pruned weights are zero, so the corresponding neurons and their gradients could be pruned (zeroed) as well. The paper mentions this, but the author's released code does not appear to do it (possibly a code-version issue); this part can be extended on your own, and one option is sketched below.
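
One possible extension (my own sketch, assuming each prunable conv stores a 0/1 mask in an attribute I call mask_buffer, a hypothetical name, not the repo's): zero the pruned weights once and register a gradient hook so finetuning cannot revive them.

import torch.nn as nn

def freeze_pruned_connections(model):
    # call once after pruning
    for m in model.modules():
        if isinstance(m, nn.Conv2d) and hasattr(m, "mask_buffer"):
            mask = m.mask_buffer
            # kill the pruned weights ...
            m.weight.data.mul_(mask)
            # ... and their gradients, so retraining keeps them at zero
            m.weight.register_hook(lambda grad, mask=mask: grad * mask)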

Reproduction

1: Choosing the dataset, model, optimizer, etc.
I use the CIFAR-10 dataset and ResNet-18; other details are omitted (the author's repository contains the configuration for this part, see the source). A minimal sketch of the setup follows.
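The normalization constants and hyper-parameters below are typical CIFAR-10 choices and may differ from the author's config; get_model comes from the author's repository.

import torch
import torchvision
import torchvision.transforms as transforms
from models import get_model  # from the author's repository

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

model = get_model('resnet18')
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=5e-4)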

2: Training the model
On top of the author's training code I added some visualization code:

def paint_vinz(epoch):
    global test_loss, train_loss, train_top1, test_top1, train_error_history, test_error_history

    viz.line(X=epoch_p,
             Y=np.column_stack((np.array(train_error_history), np.array(test_error_history))),
             win=line, opts=dict(legend=['train_error', 'test_error']))

    # visdom's text pane supports HTML, so each metric is wrapped in its own tag
    viz.text(
        "<h3>epoch: {}</h3>"
        "<h3>train_loss: {:.4f}</h3>"
        "<h3>test_loss: {:.4f}</h3>"
        "<h3>train_top1: {:.4f}</h3>"
        "<h3>test_top1: {:.4f}</h3>"
        "<h3>Time: {}</h3>".format(
            epoch + 1, train_loss, test_loss, train_top1, test_top1,
            time.strftime('%H:%M:%S', time.gmtime(time.time() - start_time))),
        win=text)
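
The function above assumes that viz, line, text, and epoch_p already exist as globals; they are presumably created once at startup, roughly like this (my reconstruction, the exact options may differ):

import numpy as np
import visdom

viz = visdom.Visdom()            # assumes `python -m visdom.server` is already running
epoch_p = np.array([])           # epochs plotted so far; append the new epoch before calling paint_vinz
line = viz.line(X=np.array([0]), Y=np.array([0]))   # placeholder curve window, redrawn each epoch
text = viz.text("waiting for the first epoch ...")  # placeholder text window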

3: Params & FLOPs of the unpruned model
Here I use the torchstat package to collect the model statistics; my code is below. Note that in the repo the author counts non-zero parameters rather than the total parameter count, which I come back to later.

# note: this script also relies on the author's repo (models, utils) to run
import os
import torch
from utils import *
from models import get_model
import argparse
from torchstat import stat

parser = argparse.ArgumentParser(description="model params counting test")
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')

args = parser.parse_args()

model = get_model('resnet18')

if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['net'], strict=False)
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

stat(model, (3, 32, 32))

                          module name  input shape output shape      params memory(MB)             MAdd          Flops  MemRead(B)  MemWrite(B) duration[%]   MemR+W(B)
0                   conv_bn_relu.conv    3  32  32   64  32  32      1728.0       0.25      3,473,408.0    1,769,472.0     19200.0     262144.0       0.00%    281344.0
1                     conv_bn_relu.bn   64  32  32   64  32  32       128.0       0.25        262,144.0      131,072.0    262656.0     262144.0       0.00%    524800.0
2                   conv_bn_relu.relu   64  32  32   64  32  32         0.0       0.25         65,536.0       65,536.0    262144.0     262144.0       0.00%    524288.0
3                 layer1.0.conv1.conv   64  32  32   64  32  32     36864.0       0.25     75,431,936.0   37,748,736.0    409600.0     262144.0       0.00%    671744.0
4                   layer1.0.conv1.bn   64  32  32   64  32  32       128.0       0.25        262,144.0      131,072.0    262656.0     262144.0       0.00%    524800.0
5                 layer1.0.conv1.relu   64  32  32   64  32  32         0.0       0.25         65,536.0       65,536.0    262144.0     262144.0       0.00%    524288.0
6                 layer1.0.conv2.conv   64  32  32   64  32  32     36864.0       0.25     75,431,936.0   37,748,736.0    409600.0     262144.0       0.00%    671744.0
7                   layer1.0.conv2.bn   64  32  32   64  32  32       128.0       0.25        262,144.0      131,072.0    262656.0     262144.0       0.00%    524800.0
8                 layer1.0.conv2.relu   64  32  32   64  32  32         0.0       0.25              0.0            0.0         0.0          0.0       0.00%         0.0
9                   layer1.0.shortcut   64  32  32   64  32  32         0.0       0.25              0.0            0.0         0.0          0.0       0.00%         0.0
10                layer1.1.conv1.conv   64  32  32   64  32  32     36864.0       0.25     75,431,936.0   37,748,736.0    409600.0     262144.0       0.00%    671744.0
11                  layer1.1.conv1.bn   64  32  32   64  32  32       128.0       0.25        262,144.0      131,072.0    262656.0     262144.0       0.00%    524800.0
12                layer1.1.conv1.relu   64  32  32   64  32  32         0.0       0.25         65,536.0       65,536.0    262144.0     262144.0       0.00%    524288.0
13                layer1.1.conv2.conv   64  32  32   64  32  32     36864.0       0.25     75,431,936.0   37,748,736.0    409600.0     262144.0       0.00%    671744.0
14                  layer1.1.conv2.bn   64  32  32   64  32  32       128.0       0.25        262,144.0      131,072.0    262656.0     262144.0       0.00%    524800.0
15                layer1.1.conv2.relu   64  32  32   64  32  32         0.0       0.25              0.0            0.0         0.0          0.0       0.00%         0.0
16                  layer1.1.shortcut   64  32  32   64  32  32         0.0       0.25              0.0            0.0         0.0          0.0       0.00%         0.0
17                layer2.0.conv1.conv   64  32  32  128  16  16     73728.0       0.12     37,715,968.0   18,874,368.0    557056.0     131072.0       0.00%    688128.0
18                  layer2.0.conv1.bn  128  16  16  128  16  16       256.0       0.12        131,072.0       65,536.0    132096.0     131072.0       0.00%    263168.0
19                layer2.0.conv1.relu  128  16  16  128  16  16         0.0       0.12         32,768.0       32,768.0    131072.0     131072.0       0.00%    262144.0
20                layer2.0.conv2.conv  128  16  16  128  16  16    147456.0       0.12     75,464,704.0   37,748,736.0    720896.0     131072.0       0.00%    851968.0
21                  layer2.0.conv2.bn  128  16  16  128  16  16       256.0       0.12        131,072.0       65,536.0    132096.0     131072.0       0.00%    263168.0
22                layer2.0.conv2.relu  128  16  16  128  16  16         0.0       0.12              0.0            0.0         0.0          0.0       0.00%         0.0
23     layer2.0.shortcut.conv_bn.conv   64  32  32  128  16  16      8192.0       0.12      4,161,536.0    2,097,152.0    294912.0     131072.0       0.00%    425984.0
24       layer2.0.shortcut.conv_bn.bn  128  16  16  128  16  16       256.0       0.12        131,072.0       65,536.0    132096.0     131072.0       0.00%    263168.0
25     layer2.0.shortcut.conv_bn.relu  128  16  16  128  16  16         0.0       0.12              0.0            0.0         0.0          0.0       0.00%         0.0
26                layer2.1.conv1.conv  128  16  16  128  16  16    147456.0       0.12     75,464,704.0   37,748,736.0    720896.0     131072.0       0.00%    851968.0
27                  layer2.1.conv1.bn  128  16  16  128  16  16       256.0       0.12        131,072.0       65,536.0    132096.0     131072.0       0.00%    263168.0
28                layer2.1.conv1.relu  128  16  16  128  16  16         0.0       0.12         32,768.0       32,768.0    131072.0     131072.0       0.00%    262144.0
29                layer2.1.conv2.conv  128  16  16  128  16  16    147456.0       0.12     75,464,704.0   37,748,736.0    720896.0     131072.0       0.00%    851968.0
30                  layer2.1.conv2.bn  128  16  16  128  16  16       256.0       0.12        131,072.0       65,536.0    132096.0     131072.0       0.00%    263168.0
31                layer2.1.conv2.relu  128  16  16  128  16  16         0.0       0.12              0.0            0.0         0.0          0.0       0.00%         0.0
32                  layer2.1.shortcut  128  16  16  128  16  16         0.0       0.12              0.0            0.0         0.0          0.0       0.00%         0.0
33                layer3.0.conv1.conv  128  16  16  256   8   8    294912.0       0.06     37,732,352.0   18,874,368.0   1310720.0      65536.0       0.00%   1376256.0
34                  layer3.0.conv1.bn  256   8   8  256   8   8       512.0       0.06         65,536.0       32,768.0     67584.0      65536.0       0.00%    133120.0
35                layer3.0.conv1.relu  256   8   8  256   8   8         0.0       0.06         16,384.0       16,384.0     65536.0      65536.0       0.00%    131072.0
36                layer3.0.conv2.conv  256   8   8  256   8   8    589824.0       0.06     75,481,088.0   37,748,736.0   2424832.0      65536.0       0.00%   2490368.0
37                  layer3.0.conv2.bn  256   8   8  256   8   8       512.0       0.06         65,536.0       32,768.0     67584.0      65536.0       0.00%    133120.0
38                layer3.0.conv2.relu  256   8   8  256   8   8         0.0       0.06              0.0            0.0         0.0          0.0       0.00%         0.0
39     layer3.0.shortcut.conv_bn.conv  128  16  16  256   8   8     32768.0       0.06      4,177,920.0    2,097,152.0    262144.0      65536.0       0.00%    327680.0
40       layer3.0.shortcut.conv_bn.bn  256   8   8  256   8   8       512.0       0.06         65,536.0       32,768.0     67584.0      65536.0       0.00%    133120.0
41     layer3.0.shortcut.conv_bn.relu  256   8   8  256   8   8         0.0       0.06              0.0            0.0         0.0          0.0       0.00%         0.0
42                layer3.1.conv1.conv  256   8   8  256   8   8    589824.0       0.06     75,481,088.0   37,748,736.0   2424832.0      65536.0       0.00%   2490368.0
43                  layer3.1.conv1.bn  256   8   8  256   8   8       512.0       0.06         65,536.0       32,768.0     67584.0      65536.0       0.00%    133120.0
44                layer3.1.conv1.relu  256   8   8  256   8   8         0.0       0.06         16,384.0       16,384.0     65536.0      65536.0       0.00%    131072.0
45                layer3.1.conv2.conv  256   8   8  256   8   8    589824.0       0.06     75,481,088.0   37,748,736.0   2424832.0      65536.0       0.00%   2490368.0
46                  layer3.1.conv2.bn  256   8   8  256   8   8       512.0       0.06         65,536.0       32,768.0     67584.0      65536.0       0.00%    133120.0
47                layer3.1.conv2.relu  256   8   8  256   8   8         0.0       0.06              0.0            0.0         0.0          0.0       0.00%         0.0
48                  layer3.1.shortcut  256   8   8  256   8   8         0.0       0.06              0.0            0.0         0.0          0.0       0.00%         0.0
49                layer4.0.conv1.conv  256   8   8  512   4   4   1179648.0       0.03     37,740,544.0   18,874,368.0   4784128.0      32768.0       0.00%   4816896.0
50                  layer4.0.conv1.bn  512   4   4  512   4   4      1024.0       0.03         32,768.0       16,384.0     36864.0      32768.0       0.00%     69632.0
51                layer4.0.conv1.relu  512   4   4  512   4   4         0.0       0.03          8,192.0        8,192.0     32768.0      32768.0       0.00%     65536.0
52                layer4.0.conv2.conv  512   4   4  512   4   4   2359296.0       0.03     75,489,280.0   37,748,736.0   9469952.0      32768.0       0.00%   9502720.0
53                  layer4.0.conv2.bn  512   4   4  512   4   4      1024.0       0.03         32,768.0       16,384.0     36864.0      32768.0       0.00%     69632.0
54                layer4.0.conv2.relu  512   4   4  512   4   4         0.0       0.03              0.0            0.0         0.0          0.0       0.00%         0.0
55     layer4.0.shortcut.conv_bn.conv  256   8   8  512   4   4    131072.0       0.03      4,186,112.0    2,097,152.0    589824.0      32768.0       0.00%    622592.0
56       layer4.0.shortcut.conv_bn.bn  512   4   4  512   4   4      1024.0       0.03         32,768.0       16,384.0     36864.0      32768.0       0.00%     69632.0
57     layer4.0.shortcut.conv_bn.relu  512   4   4  512   4   4         0.0       0.03              0.0            0.0         0.0          0.0       0.00%         0.0
58                layer4.1.conv1.conv  512   4   4  512   4   4   2359296.0       0.03     75,489,280.0   37,748,736.0   9469952.0      32768.0       0.00%   9502720.0
59                  layer4.1.conv1.bn  512   4   4  512   4   4      1024.0       0.03         32,768.0       16,384.0     36864.0      32768.0       0.00%     69632.0
60                layer4.1.conv1.relu  512   4   4  512   4   4         0.0       0.03          8,192.0        8,192.0     32768.0      32768.0       0.00%     65536.0
61                layer4.1.conv2.conv  512   4   4  512   4   4   2359296.0       0.03     75,489,280.0   37,748,736.0   9469952.0      32768.0       0.00%   9502720.0
62                  layer4.1.conv2.bn  512   4   4  512   4   4      1024.0       0.03         32,768.0       16,384.0     36864.0      32768.0       0.00%     69632.0
63                layer4.1.conv2.relu  512   4   4  512   4   4         0.0       0.03              0.0            0.0         0.0          0.0       0.00%         0.0
64                  layer4.1.shortcut  512   4   4  512   4   4         0.0       0.03              0.0            0.0         0.0          0.0       0.00%         0.0
65                             linear          512           10      5130.0       0.00         10,230.0        5,120.0     22568.0         40.0       0.00%     22608.0
total                                                            11173962.0       7.75  1,112,999,926.0  556,962,816.0     22568.0         40.0       0.00%  57227600.0
=======================================================================================================================================================================
Total params: 11,173,962
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 7.75MB
Total MAdd: 1.11GMAdd
Total Flops: 556.96MFlops
Total MemR+W: 54.58MB

Counting the number of non-zero conv-layer parameters with the author's code:

# count only conv params for now
def get_no_params(net, verbose=False, mask=False):
    # `net` is expected to be a state_dict; only entries whose key contains
    # "conv" contribute, and only their non-zero elements are counted
    params = net
    tot = 0
    for p in params:
        no = torch.sum(params[p] != 0)
        if "conv" in p:
            tot += no
    return tot

The result:

tensor(11178452)

Although the author only counts conv-layer parameters, the result is almost identical to the torchstat total, because conv layers account for the overwhelming majority of ResNet's parameters, as the per-layer statistics above show. Nothing has been pruned yet, so every parameter is non-zero and the non-zero conv count simply equals the total conv parameter count.
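
A quick way to check this (my own helper, not from the repo) is to compare the conv weight count against everything in the state_dict; for this ResNet-18 the conv layers hold roughly 99% of the parameters:

def conv_param_fraction(state_dict):
    # fraction of state_dict entries (by element count) that belong to conv weights
    total = sum(v.numel() for v in state_dict.values())
    conv = sum(v.numel() for k, v in state_dict.items()
               if "conv" in k and k.endswith("weight"))
    return conv / total

# print(conv_param_fraction(model.state_dict()))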

4: Pruning the model weights
The author's released code implements the L1 (magnitude) variant in the unstructured_prune function. It does not modify the model's state_dict directly; instead it generates a mask, so after the function runs only the model's masks have been updated. Reading the source shows that the weights are only actually changed during the forward pass (a sketch of such a masked conv follows the code below).

    def unstructured_prune(self, model, prune_rate=50.0):

        # get all the prunable convolutions
        convs = model.get_prunable_layers(pruning_type=self.pruning_type)

        # collate all weights into a single vector so l1-threshold can be calculated
        all_weights = torch.Tensor()
        if torch.cuda.is_available():
            all_weights = all_weights.cuda()
        for conv in convs:
            # concatenate the weights of every prunable layer into one vector
            all_weights = torch.cat((all_weights.view(-1), conv.conv.weight.view(-1)))
        abs_weights = torch.abs(all_weights.detach())
        # take the given percentile of |w| as the global pruning threshold
        threshold = np.percentile(abs_weights.cpu().numpy(), prune_rate)
        # prune anything beneath l1-threshold
        for conv in model.get_prunable_layers(pruning_type=self.pruning_type):
            conv.mask.update(
                # element-wise multiply with the existing mask
                torch.mul(
                    torch.gt(torch.abs(conv.conv.weight), torch.as_tensor(np.array(threshold).astype('float')).cuda()).float(),
                    conv.mask.mask.weight.cuda(),
                )
            )
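
A stripped-down version of what such a masked convolution presumably looks like (a sketch, not the repo's actual class): the stored weight tensor is left untouched, and the zeros only appear in the product used inside forward().

import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedConv2d(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, **kwargs)
        # 0/1 mask of the same shape as the weight; a buffer, not a learnable parameter
        self.register_buffer("mask", torch.ones_like(self.conv.weight))

    def forward(self, x):
        # pruning takes effect here: the stored weight itself keeps its old values
        return F.conv2d(x, self.conv.weight * self.mask, self.conv.bias,
                        self.conv.stride, self.conv.padding,
                        self.conv.dilation, self.conv.groups)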

Pruning is followed by finetuning: each pruning step is followed by 100 batches of training. During finetuning, parameters that were set to zero receive gradient updates again, so after the update they are generally no longer zero.

def finetune(model, trainloader, criterion, optimizer, steps=100):
    # switch to train mode
    model.train()
    dataiter = iter(trainloader)
    for i in range(steps):
        try:
            input, target = next(dataiter)
        except StopIteration:
            dataiter = iter(trainloader)
            input, target = next(dataiter)

        input, target = input.to(device), target.to(device)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Checking how the zeroed parameters change after finetuning:

pruner.prune(model, prune_rate)
print("训练前=================={}".format(get_no_params(model.state_dict())))
finetune(model, trainloader, criterion, optimizer, args.finetune_steps)
print("训练后=================={}".format(get_no_params(model.state_dict())))
validate(model, prune_rate, testloader, criterion, optimizer)
print("测试后=================={}".format(get_no_params(model.state_dict())))

The output:

before finetune ==================11178452
after finetune ==================11178452
after validate ==================10062528

The experiment above shows that after finetuning all of the zeroed parameters are trained back to non-zero values (arguably making the model more robust), and the actual pruning effect only materializes in the forward pass at test time (a possible fix for keeping the sparsity is sketched below).
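
If you would rather keep the pruned connections at zero during finetuning, as the paper's original pipeline does, one simple option is to re-apply the masks right after optimizer.step(). The sketch below assumes each prunable conv exposes its 0/1 mask as conv.mask.mask.weight, the attribute used in the pruning code above; the default pruning_type value is my guess.

import torch

def reapply_masks(model, pruning_type="unstructured"):
    # call right after optimizer.step() inside finetune()
    with torch.no_grad():
        for conv in model.get_prunable_layers(pruning_type=pruning_type):
            mask = conv.mask.mask.weight.to(conv.conv.weight.device)
            conv.conv.weight.mul_(mask)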

Increasing the prune rate step by step gives the results below; the model still performs well at a prune rate of 85%.

prune rate:0.0, no zero param:11178451, err:4.16
prune rate:5.0, no zero param:10620490, err:4.64
prune rate:10.0, no zero param:10062528, err:4.66
prune rate:15.0, no zero param:9504567, err:4.69
prune rate:20.0, no zero param:8946605, err:4.5
prune rate:25.0, no zero param:8388644, err:4.52
prune rate:30.0, no zero param:7830680, err:4.53
prune rate:35.0, no zero param:7272721, err:4.53
prune rate:40.0, no zero param:6714759, err:4.51
prune rate:45.0, no zero param:6156797, err:4.44
prune rate:50.0, no zero param:5598836, err:4.52
prune rate:55.0, no zero param:5040874, err:4.5
prune rate:60.0, no zero param:4482913, err:4.68
prune rate:65.0, no zero param:3924949, err:4.67
prune rate:70.0, no zero param:3366990, err:4.64
prune rate:75.0, no zero param:2809028, err:4.67
prune rate:80.0, no zero param:2251067, err:5.01
prune rate:85.0, no zero param:1693105, err:5.46
prune rate:90.0, no zero param:1135143, err:8.22
prune rate:95.0, no zero param:577182, err:31.93

Now the key experiment: running torchstat on the model pruned at a rate of 85%. As the results below show, the reported parameter count and FLOPs do not change at all, even though only about 1.7M weights are non-zero, roughly 15% of the original 11.17M.

                          module name  input shape output shape      params memory(MB)             MAdd          Flops  MemRead(B)  MemWrite(B) duration[%]   MemR+W(B)
0                   conv_bn_relu.conv    3  32  32   64  32  32      1728.0       0.25      3,473,408.0    1,769,472.0     19200.0     262144.0       0.00%    281344.0
1                     conv_bn_relu.bn   64  32  32   64  32  32       128.0       0.25        262,144.0      131,072.0    262656.0     262144.0       0.00%    524800.0
2                   conv_bn_relu.relu   64  32  32   64  32  32         0.0       0.25         65,536.0       65,536.0    262144.0     262144.0       0.00%    524288.0
3                 layer1.0.conv1.conv   64  32  32   64  32  32     36864.0       0.25     75,431,936.0   37,748,736.0    409600.0     262144.0      89.46%    671744.0
4                   layer1.0.conv1.bn   64  32  32   64  32  32       128.0       0.25        262,144.0      131,072.0    262656.0     262144.0       0.00%    524800.0
5                 layer1.0.conv1.relu   64  32  32   64  32  32         0.0       0.25         65,536.0       65,536.0    262144.0     262144.0       0.00%    524288.0
6                 layer1.0.conv2.conv   64  32  32   64  32  32     36864.0       0.25     75,431,936.0   37,748,736.0    409600.0     262144.0       0.00%    671744.0
7                   layer1.0.conv2.bn   64  32  32   64  32  32       128.0       0.25        262,144.0      131,072.0    262656.0     262144.0       0.00%    524800.0
8                 layer1.0.conv2.relu   64  32  32   64  32  32         0.0       0.25              0.0            0.0         0.0          0.0       0.00%         0.0
9                   layer1.0.shortcut   64  32  32   64  32  32         0.0       0.25              0.0            0.0         0.0          0.0       0.00%         0.0
10                layer1.1.conv1.conv   64  32  32   64  32  32     36864.0       0.25     75,431,936.0   37,748,736.0    409600.0     262144.0       0.00%    671744.0
11                  layer1.1.conv1.bn   64  32  32   64  32  32       128.0       0.25        262,144.0      131,072.0    262656.0     262144.0       0.00%    524800.0
12                layer1.1.conv1.relu   64  32  32   64  32  32         0.0       0.25         65,536.0       65,536.0    262144.0     262144.0       0.00%    524288.0
13                layer1.1.conv2.conv   64  32  32   64  32  32     36864.0       0.25     75,431,936.0   37,748,736.0    409600.0     262144.0       0.00%    671744.0
14                  layer1.1.conv2.bn   64  32  32   64  32  32       128.0       0.25        262,144.0      131,072.0    262656.0     262144.0       0.00%    524800.0
15                layer1.1.conv2.relu   64  32  32   64  32  32         0.0       0.25              0.0            0.0         0.0          0.0       0.00%         0.0
16                  layer1.1.shortcut   64  32  32   64  32  32         0.0       0.25              0.0            0.0         0.0          0.0       0.00%         0.0
17                layer2.0.conv1.conv   64  32  32  128  16  16     73728.0       0.12     37,715,968.0   18,874,368.0    557056.0     131072.0       0.00%    688128.0
18                  layer2.0.conv1.bn  128  16  16  128  16  16       256.0       0.12        131,072.0       65,536.0    132096.0     131072.0       0.00%    263168.0
19                layer2.0.conv1.relu  128  16  16  128  16  16         0.0       0.12         32,768.0       32,768.0    131072.0     131072.0       0.00%    262144.0
20                layer2.0.conv2.conv  128  16  16  128  16  16    147456.0       0.12     75,464,704.0   37,748,736.0    720896.0     131072.0       0.00%    851968.0
21                  layer2.0.conv2.bn  128  16  16  128  16  16       256.0       0.12        131,072.0       65,536.0    132096.0     131072.0       0.00%    263168.0
22                layer2.0.conv2.relu  128  16  16  128  16  16         0.0       0.12              0.0            0.0         0.0          0.0       0.00%         0.0
23     layer2.0.shortcut.conv_bn.conv   64  32  32  128  16  16      8192.0       0.12      4,161,536.0    2,097,152.0    294912.0     131072.0       0.00%    425984.0
24       layer2.0.shortcut.conv_bn.bn  128  16  16  128  16  16       256.0       0.12        131,072.0       65,536.0    132096.0     131072.0       0.00%    263168.0
25     layer2.0.shortcut.conv_bn.relu  128  16  16  128  16  16         0.0       0.12              0.0            0.0         0.0          0.0       0.00%         0.0
26                layer2.1.conv1.conv  128  16  16  128  16  16    147456.0       0.12     75,464,704.0   37,748,736.0    720896.0     131072.0       9.69%    851968.0
27                  layer2.1.conv1.bn  128  16  16  128  16  16       256.0       0.12        131,072.0       65,536.0    132096.0     131072.0       0.00%    263168.0
28                layer2.1.conv1.relu  128  16  16  128  16  16         0.0       0.12         32,768.0       32,768.0    131072.0     131072.0       0.00%    262144.0
29                layer2.1.conv2.conv  128  16  16  128  16  16    147456.0       0.12     75,464,704.0   37,748,736.0    720896.0     131072.0       0.00%    851968.0
30                  layer2.1.conv2.bn  128  16  16  128  16  16       256.0       0.12        131,072.0       65,536.0    132096.0     131072.0       0.00%    263168.0
31                layer2.1.conv2.relu  128  16  16  128  16  16         0.0       0.12              0.0            0.0         0.0          0.0       0.00%         0.0
32                  layer2.1.shortcut  128  16  16  128  16  16         0.0       0.12              0.0            0.0         0.0          0.0       0.00%         0.0
33                layer3.0.conv1.conv  128  16  16  256   8   8    294912.0       0.06     37,732,352.0   18,874,368.0   1310720.0      65536.0       0.00%   1376256.0
34                  layer3.0.conv1.bn  256   8   8  256   8   8       512.0       0.06         65,536.0       32,768.0     67584.0      65536.0       0.00%    133120.0
47                layer3.1.conv2.relu  256   8   8  256   8   8         0.0       0.06              0.0            0.0         0.0          0.0       0.00%         0.0
48                  layer3.1.shortcut  256   8   8  256   8   8         0.0       0.06              0.0            0.0         0.0          0.0       0.00%         0.0
49                layer4.0.conv1.conv  256   8   8  512   4   4   1179648.0       0.03     37,740,544.0   18,874,368.0   4784128.0      32768.0       0.00%   4816896.0
50                  layer4.0.conv1.bn  512   4   4  512   4   4      1024.0       0.03         32,768.0       16,384.0     36864.0      32768.0       0.00%     69632.0
51                layer4.0.conv1.relu  512   4   4  512   4   4         0.0       0.03          8,192.0        8,192.0     32768.0      32768.0       0.00%     65536.0
52                layer4.0.conv2.conv  512   4   4  512   4   4   2359296.0       0.03     75,489,280.0   37,748,736.0   9469952.0      32768.0       0.00%   9502720.0
53                  layer4.0.conv2.bn  512   4   4  512   4   4      1024.0       0.03         32,768.0       16,384.0     36864.0      32768.0       0.84%     69632.0
54                layer4.0.conv2.relu  512   4   4  512   4   4         0.0       0.03              0.0            0.0         0.0          0.0       0.00%         0.0
55     layer4.0.shortcut.conv_bn.conv  256   8   8  512   4   4    131072.0       0.03      4,186,112.0    2,097,152.0    589824.0      32768.0       0.00%    622592.0
56       layer4.0.shortcut.conv_bn.bn  512   4   4  512   4   4      1024.0       0.03         32,768.0       16,384.0     36864.0      32768.0       0.00%     69632.0
57     layer4.0.shortcut.conv_bn.relu  512   4   4  512   4   4         0.0       0.03              0.0            0.0         0.0          0.0       0.00%         0.0
58                layer4.1.conv1.conv  512   4   4  512   4   4   2359296.0       0.03     75,489,280.0   37,748,736.0   9469952.0      32768.0       0.00%   9502720.0
59                  layer4.1.conv1.bn  512   4   4  512   4   4      1024.0       0.03         32,768.0       16,384.0     36864.0      32768.0       0.00%     69632.0
60                layer4.1.conv1.relu  512   4   4  512   4   4         0.0       0.03          8,192.0        8,192.0     32768.0      32768.0       0.00%     65536.0
61                layer4.1.conv2.conv  512   4   4  512   4   4   2359296.0       0.03     75,489,280.0   37,748,736.0   9469952.0      32768.0       0.00%   9502720.0
62                  layer4.1.conv2.bn  512   4   4  512   4   4      1024.0       0.03         32,768.0       16,384.0     36864.0      32768.0       0.00%     69632.0
63                layer4.1.conv2.relu  512   4   4  512   4   4         0.0       0.03              0.0            0.0         0.0          0.0       0.00%         0.0
64                  layer4.1.shortcut  512   4   4  512   4   4         0.0       0.03              0.0            0.0         0.0          0.0       0.00%         0.0
65                             linear          512           10      5130.0       0.00         10,230.0        5,120.0     22568.0         40.0       0.00%     22608.0
total                                                            11173962.0       7.75  1,112,999,926.0  556,962,816.0     22568.0         40.0     100.00%  57227600.0
=======================================================================================================================================================================
Total params: 11,173,962
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------
Total memory: 7.75MB
Total MAdd: 1.11GMAdd
Total Flops: 556.96MFlops
Total MemR+W: 54.58MB
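
What the dense counters do not show is the storage saving. A rough back-of-the-envelope calculation for fp32 weights, using the numbers above and ignoring the index overhead a sparse format would add:

total_params = 11_173_962      # torchstat total
nonzero_params = 1_693_105     # non-zero count at prune rate 85%

dense_mb = total_params * 4 / 2**20       # ~42.6 MB: what actually gets stored and computed
values_mb = nonzero_params * 4 / 2**20    # ~6.5 MB if only the non-zero values were kept
print(f"dense: {dense_mb:.1f} MB, non-zero values only: {values_mb:.1f} MB")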

Summary

Reproducing this paper really drove home the frustrating side of unstructured pruning: the zeroed parameters bring no practical benefit on their own. This is why several papers note that sparse weight matrices yield no real speedup, since they cannot exploit dense BLAS libraries. Getting actual acceleration out of sparse matrices requires matching hardware and software support, which is a large undertaking that I will not pursue further here.
That said, the paper's three-step idea is a genuine classic and also plays an important role in structured pruning; I will keep posting reproductions of structured-pruning papers.
