Model Pruning in Practice

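High-performing deep learning models usually have a very large number of parameters, many of them redundant, which makes them hard to deploy. Biological neural networks, by contrast, rely on efficient sparse connectivity. Pruning parameters according to their importance can reduce parameter storage, computation, and the energy consumed by the hardware. This article shows how to sparsify parameters with PyTorch's pruning utility, torch.nn.utils.prune.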

Requirements

torch.nn.utils.prune is available from PyTorch 1.4.0 onward.
The libraries to import are shown in the code below:

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

Creating the model

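In this example we use LeNet.
Note that in practice the model to be pruned should already be trained; here it is freshly initialized for demonstration.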

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)
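fc1 expects 16 * 5 * 5 = 400 input features, which fixes the input image size. As a quick check (my own arithmetic, not part of the original code), the plain-Python sketch below traces the spatial size through the two 3x3 convolutions (stride 1, no padding) and the 2x2 max pools: a 28x28 input (MNIST-sized) yields 5x5 feature maps, while a 32x32 input would yield 6x6, i.e. 576 features, and fc1 would fail.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    # Output size of nn.Conv2d along one spatial dimension
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2):
    # F.max_pool2d with stride equal to the kernel size (its default), floor mode
    return (size - kernel) // kernel + 1


def lenet_flat_features(h):
    # Number of features fed to fc1 for an h x h single-channel input
    h = pool_out(conv_out(h))  # conv1 -> relu -> 2x2 max pool
    h = pool_out(conv_out(h))  # conv2 -> relu -> 2x2 max pool
    return 16 * h * h          # conv2 has 16 output channels


print(lenet_flat_features(28))  # 400 == 16 * 5 * 5, matches fc1
print(lenet_flat_features(32))  # 576, would not match fc1
```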

Inspecting the module

Let's inspect the (unpruned) conv1 layer of LeNet. It contains two parameters, weight and bias, and no buffers so far.

module = model.conv1
print(list(module.named_parameters()))
Out:

[('weight', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],


        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],


        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],


        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
Out:

[]

Pruning a module

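To prune conv1, first choose a pruning technique from those available in torch.nn.utils.prune (or implement your own by subclassing BasePruningMethod). Then specify the module and the name of the parameter within that module to prune. Finally, pass any keyword arguments the chosen technique requires.
In this example we randomly prune 30% of the connections in the parameter named weight. The name argument is the parameter's name within the module, and amount is the fraction to prune, a float between 0 and 1 (a non-negative integer is instead treated as the absolute number of entries to prune).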

prune.random_unstructured(module, name="weight", amount=0.3)
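To see how amount maps to a number of entries: conv1.weight has 6 * 1 * 3 * 3 = 54 entries. The sketch below assumes, based on my reading of the library, that a float amount is multiplied by the number of entries and rounded to the nearest integer, while an integer amount is used as-is; entries_to_prune is a hypothetical helper, not a PyTorch function. With amount=0.3 it gives 16, which matches the 16 zeros in the weight_mask printed in this section.

```python
def entries_to_prune(amount, numel):
    """Hypothetical helper: number of entries removed for a given `amount`."""
    if isinstance(amount, int):
        # An integer amount is an absolute count of entries
        if not 0 <= amount <= numel:
            raise ValueError("amount must be between 0 and the number of entries")
        return amount
    # A float amount is a fraction of all entries, rounded to the nearest integer
    if not 0.0 <= amount <= 1.0:
        raise ValueError("a float amount must be between 0 and 1")
    return round(amount * numel)


conv1_weight_numel = 6 * 1 * 3 * 3  # out_channels * in_channels * kernel_h * kernel_w
print(conv1_weight_numel)                         # 54
print(entries_to_prune(0.3, conv1_weight_numel))  # 16
print(entries_to_prune(3, 6))                     # 3, as in the bias example below
```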

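Pruning removes weight from the module's parameters and replaces it with a new parameter called weight_orig (the original parameter name with _orig appended). weight_orig stores the unpruned version of the tensor. The bias was not pruned, so it remains unchanged.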

print(list(module.named_parameters()))
Out:

[('bias', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],


        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],


        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],


        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True))]

The pruning mask generated by the technique is saved as a module buffer named weight_mask (the original parameter name with _mask appended).

print(list(module.named_buffers()))
Out:

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 0., 0.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 0.],
          [1., 0., 1.]]],


        [[[1., 0., 0.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [0., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 0.]]]], device='cuda:0'))]

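So that the forward pass works without modification, the weight attribute must still exist. The pruned weight is computed by multiplying the original parameter by the mask and is stored in the weight attribute. Note that weight is no longer a parameter of the module, just a plain attribute.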

print(module.weight)
Out:

tensor([[[[ 0.0000, -0.2212,  0.0000],
          [ 0.2488,  0.0000,  0.0000],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0419,  0.0000, -0.2106],
          [ 0.1776, -0.1845, -0.0000],
          [-0.0708,  0.0000,  0.3095]]],


        [[[-0.2070,  0.0000,  0.0000],
          [ 0.0000,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.2799, -0.0000, -0.0000],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],


        [[[ 0.2035,  0.0000, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.0000, -0.2145, -0.1947]]],


        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.0000],
          [ 0.2159, -0.1725,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)
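The reparametrization can be checked on a small layer. This sketch uses a hypothetical nn.Linear(4, 3) instead of conv1 (so it runs on CPU without LeNet) and verifies that weight_orig is the parameter, weight_mask is a buffer, and the weight attribute equals weight_orig * weight_mask.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)  # small stand-in for conv1
prune.random_unstructured(layer, name="weight", amount=0.5)

# weight_orig is now the trainable parameter and weight_mask a buffer
print([name for name, _ in layer.named_parameters()])  # ['bias', 'weight_orig']
print([name for name, _ in layer.named_buffers()])     # ['weight_mask']

# The weight attribute is the element-wise product of the two
print(torch.equal(layer.weight, layer.weight_orig * layer.weight_mask))  # True
print(int((layer.weight_mask == 0).sum()))  # 6, i.e. half of the 12 entries
```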

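Finally, pruning is applied before each forward pass using PyTorch's forward_pre_hooks. Specifically, once a module has been pruned, it gets one forward_pre_hook for each of its parameters that has been pruned. Since we have so far pruned only the parameter named weight, there is only one hook.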

print(module._forward_pre_hooks)
Out:

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f700bd95b00>)])
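One consequence of the hook is that the weight attribute is only refreshed when the module runs forward. In the sketch below (again on a small stand-in nn.Linear, an assumption for illustration), weight_orig is overwritten with ones; after one forward call, the hook has recomputed weight, which then equals the mask itself.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)
prune.l1_unstructured(layer, name="weight", amount=0.5)
print(len(layer._forward_pre_hooks))  # 1: one hook per pruned parameter

# Change the underlying parameter directly; the cached weight attribute is now stale
with torch.no_grad():
    layer.weight_orig.fill_(1.0)
before = layer.weight.clone()

# The forward_pre_hook recomputes weight = weight_orig * weight_mask before forward runs
layer(torch.randn(2, 4))
print(torch.equal(layer.weight, layer.weight_mask))  # True, since weight_orig is all ones
print(torch.equal(before, layer.weight))             # False
```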

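For completeness, let's also prune the bias and see how the module's parameters, buffers, hooks, and attributes change. To try out a different pruning technique, we prune the 3 entries of bias with the smallest L1 norm (absolute value), using l1_unstructured: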

prune.l1_unstructured(module, name="bias", amount=3)

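The named parameters now include both weight_orig and bias_orig, the buffers include weight_mask and bias_mask, the pruned versions of both tensors exist as module attributes, and the module now has two forward_pre_hooks.

print(list(module.named_parameters()))
Out: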

[('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],


        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],


        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],


        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
Out:

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 0., 0.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 0.],
          [1., 0., 1.]]],


        [[[1., 0., 0.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 0.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [0., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 0.],
          [1., 1., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]
print(module.bias)
Out:

tensor([-0.0000, -0.0000, -0.2656, -0.1519, -0.0000,  0.1425], device='cuda:0',
       grad_fn=<MulBackward0>)
print(module._forward_pre_hooks)
Out:

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f700bd95b00>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f700bd959e8>)])
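The bias mask above can be reproduced by hand. This plain-Python sketch (an illustration of the idea, not PyTorch's implementation) zeroes the k entries with the smallest absolute value; applied to the printed bias values with k=3 it gives the same bias_mask, [0, 0, 1, 1, 0, 1].

```python
def l1_unstructured_mask(values, k):
    """Mask that zeroes the k entries with the smallest absolute value."""
    order = sorted(range(len(values)), key=lambda i: abs(values[i]))
    pruned = set(order[:k])
    return [0.0 if i in pruned else 1.0 for i in range(len(values))]


# conv1.bias values printed above (bias_orig)
bias = [-0.1214, -0.0749, -0.2656, -0.1519, -0.1021, 0.1425]
print(l1_unstructured_mask(bias, 3))  # [0.0, 0.0, 1.0, 1.0, 0.0, 1.0]
```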

Iterative pruning

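The same parameter can be pruned multiple times, possibly with different pruning techniques; the effect is the combination of the successive masks, which is computed by PruningContainer's compute_mask. For example, continuing from above, suppose we want to prune conv1's weight further, this time with structured pruning along dim 0 (for a convolutional layer, the output-channel dimension; conv1 has 6 output channels), keeping the channels with the largest L2 norm. This is done with ln_structured, using n=2 and dim=0: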

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)
Out:

tensor([[[[ 0.0000, -0.2212,  0.0000],
          [ 0.2488,  0.0000,  0.0000],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[-0.2070,  0.0000,  0.0000],
          [ 0.0000,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[ 0.2035,  0.0000, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
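          [-0.0000, -0.2145, -0.1947]]],


        [[[-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)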
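Which channels does ln_structured keep? This plain-Python sketch (illustrative only; ln_structured_keep is a hypothetical helper) takes the conv1.weight values printed after random_unstructured, rounded to 4 decimals, computes the L2 norm of each output channel, and keeps the 3 largest: channels 0, 2 and 4, matching the output above. The result matches when the norms are computed on the already-masked weights; by my arithmetic, using the unpruned weight_orig values would have kept channel 1 instead of channel 2.

```python
# conv1.weight after random_unstructured (printed earlier, rounded to 4 decimals),
# one flattened 3x3 kernel per output channel
masked_weight = [
    [0.0, -0.2212, 0.0, 0.2488, 0.0, 0.0, -0.2412, -0.2400, -0.2016],
    [0.0419, 0.0, -0.2106, 0.1776, -0.1845, 0.0, -0.0708, 0.0, 0.3095],
    [-0.2070, 0.0, 0.0, 0.0, 0.2077, 0.2369, 0.2108, 0.0861, -0.2279],
    [-0.2799, 0.0, 0.0, -0.2043, 0.1220, 0.1032, -0.0755, 0.1281, 0.1077],
    [0.2035, 0.0, -0.1129, 0.3257, -0.0385, -0.0115, 0.0, -0.2145, -0.1947],
    [-0.1426, 0.2370, -0.1089, -0.2491, 0.1282, 0.0, 0.2159, -0.1725, 0.0],
]


def ln_structured_keep(channels, amount, n=2):
    """Indices of the channels kept by L_n structured pruning along dim 0."""
    norms = [sum(abs(v) ** n for v in ch) ** (1.0 / n) for ch in channels]
    n_keep = len(channels) - round(amount * len(channels))
    ranked = sorted(range(len(channels)), key=lambda i: norms[i], reverse=True)
    return sorted(ranked[:n_keep])


print(ln_structured_keep(masked_weight, amount=0.5))  # [0, 2, 4]
```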

The corresponding hook is now of type torch.nn.utils.prune.PruningContainer, which stores the history of pruning applied to weight.

for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container
Out:

[<torch.nn.utils.prune.RandomUnstructured object at 0x7f700bd95b00>, <torch.nn.utils.prune.LnStructured object at 0x7f700bd9d208>]

Serializing a pruned model

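All relevant tensors, including the mask buffers and the original parameters used to compute the pruned tensors, are stored in the model's state_dict, so they can easily be serialized and saved.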

print(model.state_dict().keys())
Out:

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
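Loading such a state_dict back requires a module with the same reparametrization: a plain nn.Linear or nn.Conv2d expects the key weight, not weight_orig and weight_mask, so a strict load_state_dict fails. One way around this, sketched below on a small nn.Linear and not taken from the original article, is to call prune.identity on the target module first, which creates weight_orig and an all-ones weight_mask that the loaded values then overwrite. (Alternatively, call prune.remove before saving.)

```python
import io

import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
src = nn.Linear(4, 3)
prune.l1_unstructured(src, name="weight", amount=0.5)

# Save and reload the pruned state_dict (an in-memory buffer stands in for a file)
buffer = io.BytesIO()
torch.save(src.state_dict(), buffer)
buffer.seek(0)
state = torch.load(buffer)
print(sorted(state.keys()))  # ['bias', 'weight_mask', 'weight_orig']

# Give the target module the same weight_orig / weight_mask layout before loading
dst = nn.Linear(4, 3)
prune.identity(dst, name="weight")
dst.load_state_dict(state)

# The forward_pre_hook rebuilds weight from the loaded weight_orig and weight_mask
x = torch.randn(2, 4)
print(torch.allclose(dst(x), src(x)))  # True
```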

Removing the pruning reparametrization

Before removal, the module looks like this:

print(list(module.named_parameters()))
Out:

[('weight_orig', Parameter containing:
tensor([[[[ 0.3161, -0.2212,  0.0417],
          [ 0.2488,  0.2415,  0.2071],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0419,  0.3322, -0.2106],
          [ 0.1776, -0.1845, -0.3134],
          [-0.0708,  0.1921,  0.3095]]],


        [[[-0.2070,  0.0723,  0.2876],
          [ 0.2209,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.2799, -0.1527, -0.0388],
          [-0.2043,  0.1220,  0.1032],
          [-0.0755,  0.1281,  0.1077]]],


        [[[ 0.2035,  0.2245, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.3146, -0.2145, -0.1947]]],


        [[[-0.1426,  0.2370, -0.1089],
          [-0.2491,  0.1282,  0.1067],
          [ 0.2159, -0.1725,  0.0723]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True))]
print(list(module.named_buffers()))
Out:

[('weight_mask', tensor([[[[0., 1., 0.],
          [1., 0., 0.],
          [1., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[1., 0., 0.],
          [0., 1., 1.],
          [1., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [0., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]
print(module.weight)
Out:

tensor([[[[ 0.0000, -0.2212,  0.0000],
          [ 0.2488,  0.0000,  0.0000],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[-0.2070,  0.0000,  0.0000],
          [ 0.0000,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[ 0.2035,  0.0000, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.0000, -0.2145, -0.1947]]],


        [[[-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

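Making pruning permanent:
To make pruning permanent, remove the reparametrization in terms of weight_orig and weight_mask, as well as the forward_pre_hook, using remove from torch.nn.utils.prune. Note that this does not undo the pruning: it writes the pruned values into a regular weight parameter of the model. After removal: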

prune.remove(module, 'weight')
print(list(module.named_parameters()))
Out:

[('bias_orig', Parameter containing:
tensor([-0.1214, -0.0749, -0.2656, -0.1519, -0.1021,  0.1425], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000, -0.2212,  0.0000],
          [ 0.2488,  0.0000,  0.0000],
          [-0.2412, -0.2400, -0.2016]]],


        [[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[-0.2070,  0.0000,  0.0000],
          [ 0.0000,  0.2077,  0.2369],
          [ 0.2108,  0.0861, -0.2279]]],


        [[[-0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000]]],


        [[[ 0.2035,  0.0000, -0.1129],
          [ 0.3257, -0.0385, -0.0115],
          [-0.0000, -0.2145, -0.1947]]],


        [[[-0.0000,  0.0000, -0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000]]]], device='cuda:0', requires_grad=True))]
print(list(module.named_buffers()))
Out:

[('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]
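After prune.remove, the module is an ordinary module again. The sketch below (a small nn.Linear as a stand-in for conv1) checks that weight is a regular parameter again, that the mask buffer and the hook are gone, and that the zeros remain. Since there is no mask anymore, further training can move those zeroed weights away from zero. Calling prune.remove a second time on the same parameter raises a ValueError, because the parameter is no longer pruned.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)
prune.l1_unstructured(layer, name="weight", amount=0.5)
prune.remove(layer, "weight")

print([name for name, _ in layer.named_parameters()])  # ['bias', 'weight']
print(list(layer.named_buffers()))                    # []
print(len(layer._forward_pre_hooks))                  # 0
print(int((layer.weight == 0).sum()))                 # 6: the pruned zeros stay
```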

Pruning multiple parameters in a model

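Once a pruning technique is chosen, it can be applied to many parameters in the same model, for example depending on the layer type, as shown below: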

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist
Out:

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])

Global pruning

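So far we have only covered local pruning, i.e., pruning the tensors of a model one at a time. How can the whole model be pruned at once? Instead of removing the lowest 20% of connections in each layer, global pruning removes the lowest 20% of connections across the whole model, which generally results in a different pruning percentage for each layer.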

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

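Now check the sparsity of each pruned parameter. The sparsity of an individual layer need not be 20%, but the global sparsity is (approximately) 20%.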

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
Out:

Sparsity in conv1.weight: 7.41%
Sparsity in conv2.weight: 9.49%
Sparsity in fc1.weight: 22.00%
Sparsity in fc2.weight: 12.28%
Sparsity in fc3.weight: 9.76%
Global sparsity: 20.00%
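Why do the per-layer sparsities differ? Global pruning ranks the weights of all listed tensors together and removes the smallest ones overall, so layers with many small weights lose more. The plain-Python sketch below (hypothetical toy values, an illustration of the idea rather than PyTorch's implementation) shows this with two layers and amount=0.3.

```python
def global_l1_masks(tensors, amount):
    """Global magnitude pruning over several tensors.

    tensors: dict mapping a name to a flat list of weights.
    Returns one 0/1 mask per tensor.
    """
    entries = [(abs(v), name, i) for name, vals in tensors.items() for i, v in enumerate(vals)]
    k = round(amount * len(entries))  # number of entries to prune across all tensors
    pruned = {(name, i) for _, name, i in sorted(entries)[:k]}
    return {
        name: [0 if (name, i) in pruned else 1 for i in range(len(vals))]
        for name, vals in tensors.items()
    }


# Two toy "layers": one shared threshold, so the layer with more small weights loses more
toy = {"layer_a": [0.9, -0.1, 0.5, 0.05], "layer_b": [0.3, 0.2, -0.8, 0.7, 0.01, 0.6]}
masks = global_l1_masks(toy, amount=0.3)
for name, mask in masks.items():
    print(name, mask, f"sparsity {mask.count(0) / len(mask):.0%}")
# layer_a [1, 0, 1, 0] sparsity 50%
# layer_b [1, 1, 1, 1, 0, 1] sparsity 17%
```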

Defining your own pruning technique

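A custom pruning technique is defined by subclassing BasePruningMethod from nn.utils.prune. The base class already implements __call__, apply_mask, apply, prune, and remove, so depending on your needs you usually do not have to redefine them. What you must define is compute_mask (how the mask is computed) and, if the technique takes arguments, __init__. You also need to declare the pruning type (global, structured, or unstructured), which determines how masks are combined when pruning is applied iteratively.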
Let’s assume, for example, that you want to implement a pruning technique that prunes every other entry in a tensor (or – if the tensor has previously been pruned – in the remaining unpruned portion of the tensor). This will be of PRUNING_TYPE=‘unstructured’ because it acts on individual connections in a layer and not on entire units/channels (‘structured’), or across different parameters (‘global’).
Here is the example:

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

Now, to apply this to a parameter of an nn.Module, you should also provide a simple function that instantiates the method and applies it:

def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module)
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

Let's try it out:

model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)
Out:

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])
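The FooBar example needs no constructor arguments; a technique with a parameter also defines __init__. As a hypothetical example (ThresholdPruning is not part of torch.nn.utils.prune), the sketch below prunes every entry whose absolute value is below a threshold. BasePruningMethod.apply passes extra keyword arguments on to the constructor.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune


class ThresholdPruning(prune.BasePruningMethod):
    """Hypothetical example: prune every entry whose magnitude is below a threshold."""

    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        self.threshold = threshold

    def compute_mask(self, t, default_mask):
        # Keep an entry only if it survived earlier pruning and |t| >= threshold
        return default_mask * (t.abs() >= self.threshold).to(default_mask.dtype)


layer = nn.Linear(2, 2)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[0.05, -0.2], [0.5, -0.01]]))

# apply() forwards threshold=0.1 to ThresholdPruning.__init__
ThresholdPruning.apply(layer, "weight", threshold=0.1)
print(layer.weight_mask)  # expected mask: [[0., 1.], [1., 0.]]
```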

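Important caveat:

PyTorch's prune module is still at the research stage and currently brings no practical efficiency gain.
The prune and remove operations above only set some parameters to zero; they do not actually save memory or speed up computation.
The number of multiplications stays exactly the same.
So if you need pruning that actually shrinks the model, using PyTorch's prune module alone is not recommended.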
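The caveat is easy to observe: after pruning and remove, the weight tensor keeps its full shape. A common way to get a real saving with structured pruning, sketched below as an assumption rather than something the article covers, is to copy the surviving output channels into a physically smaller layer. This only checks the kept channels: the dropped channels still output their unpruned bias, so a real implementation would also have to adjust the next layer (for example conv2's input channels).

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(1, 6, 3)
prune.ln_structured(conv, name="weight", amount=0.5, n=2, dim=0)
prune.remove(conv, "weight")

# The weight keeps its full shape, so the convolution does as many multiplications as before
print(conv.weight.shape)  # torch.Size([6, 1, 3, 3])

# Output channels that still have non-zero weights
kept = (conv.weight.abs().sum(dim=(1, 2, 3)) != 0).nonzero().flatten()
print(kept.numel())  # 3

# Copy only the kept channels into a physically smaller layer
small = nn.Conv2d(1, kept.numel(), 3)
with torch.no_grad():
    small.weight.copy_(conv.weight[kept])
    small.bias.copy_(conv.bias[kept])

x = torch.randn(2, 1, 28, 28)
print(torch.allclose(conv(x)[:, kept], small(x), atol=1e-6))  # True
```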
