pytorch如何使用自带的模型剪枝工具prune

剪枝教程

原链接:Pruning Tutorial — PyTorch Tutorials 1.12.1+cu102 documentation

目录

摘要

前提需要

创建模型

检视模块

对一个模型进行剪枝

迭代剪枝

序列化修剪过的模型

修剪再参量化

修剪模型中的多个参数

全局剪枝

使用自定义修剪功能扩展torch.nn.utils.prune


摘要

        最先进的深度学习技术依赖于难以部署的过度参数化模型。相反,已知生物神经网络使用高效的稀疏连接。为了在不牺牲精度的情况下减少内存、电池和硬件的消耗,在设备上部署轻量级模型,并通过私有设备上的计算保证隐私性,确定通过减少模型中的参数数量来压缩模型的最佳技术是很重要的。在研究方面,修剪被用于研究过度参数化和欠参数化网络之间学习动态的差异,研究幸运稀疏子网络和初始化(彩票)作为破坏性神经结构搜索技术的作用,等等。
        在本教程中,您将学习如何使用torch.nn.utils.prune来稀疏化您的神经网络,以及如何扩展它来实现您自己的自定义修剪技术。

前提需要

"torch>=1.4.0a0+8e8a5e0"

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

创建模型

在本教程中,我们使用LeCun等人1998年的LeNet体系结构。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

检视模块

让我们检查一下LeNet模型中的(未修剪的)conv1层。它将包含两个参数权重和偏差,目前没有缓冲区。

module = model.conv1
print(list(module.named_parameters()))

OUT:

[('weight', Parameter containing:
tensor([[[[-0.0629,  0.3243, -0.0095],
          [ 0.2714,  0.0487,  0.1099],
          [-0.2398, -0.0376,  0.0314]]],


        [[[ 0.0646, -0.1846, -0.0758],
          [-0.1228,  0.3297, -0.1311],
          [ 0.0151,  0.0279, -0.0924]]],


        [[[ 0.0007, -0.1994,  0.0332],
          [ 0.0586, -0.2853,  0.2788],
          [-0.1882,  0.0664, -0.2495]]],


        [[[ 0.1629,  0.2119,  0.1293],
          [-0.0994,  0.0930,  0.0096],
          [ 0.3129, -0.1592, -0.0972]]],


        [[[-0.1942,  0.1310,  0.2086],
          [ 0.1912,  0.0933,  0.3324],
          [ 0.1050,  0.2687,  0.3313]]],


        [[[-0.3087,  0.0690, -0.3148],
          [ 0.0454, -0.2180,  0.0274],
          [ 0.3015,  0.3015, -0.1018]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.1067, -0.1336, -0.1805, -0.1301, -0.0989,  0.2122], device='cuda:0',
       requires_grad=True))]

print(list(module.named_buffers()))

OUT:

[]

对一个模型进行剪枝

        要修剪一个模块(在本例中是LeNet架构的conv1层),首先从torch.nn.utils.prune中选择一种修剪技术(或者通过子类化BasePruningMethod实现自己的修剪技术)。然后,指定要在该模块中删除的模块和参数的名称。最后,使用所选修剪技术所需的适当关键字参数,指定修剪参数。
        在本例中,我们将在conv1层中随机删除名为weight的参数中的30%的连接。模块作为函数的第一个参数传递;Name使用它的字符串标识符标识模块中的参数;amount表示要修剪的连接的百分比(如果是0之间的浮动)。和1.),或要修剪的连接的绝对数量(如果它是非负整数)。

prune.random_unstructured(module, name="weight", amount=0.3)

OUT:

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

        修剪的方法是从参数中删除权重,并用一个名为weight_trans的新参数替换它(即在初始参数名后追加" _trans ")。weight_trans存储了张量的未修剪版本。偏见没有被消除,所以它将保持不变。

print(list(module.named_parameters()))

OUT:

[('bias', Parameter containing:
tensor([ 0.1067, -0.1336, -0.1805, -0.1301, -0.0989,  0.2122], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.0629,  0.3243, -0.0095],
          [ 0.2714,  0.0487,  0.1099],
          [-0.2398, -0.0376,  0.0314]]],


        [[[ 0.0646, -0.1846, -0.0758],
          [-0.1228,  0.3297, -0.1311],
          [ 0.0151,  0.0279, -0.0924]]],


        [[[ 0.0007, -0.1994,  0.0332],
          [ 0.0586, -0.2853,  0.2788],
          [-0.1882,  0.0664, -0.2495]]],


        [[[ 0.1629,  0.2119,  0.1293],
          [-0.0994,  0.0930,  0.0096],
          [ 0.3129, -0.1592, -0.0972]]],


        [[[-0.1942,  0.1310,  0.2086],
          [ 0.1912,  0.0933,  0.3324],
          [ 0.1050,  0.2687,  0.3313]]],


        [[[-0.3087,  0.0690, -0.3148],
          [ 0.0454, -0.2180,  0.0274],
          [ 0.3015,  0.3015, -0.1018]]]], device='cuda:0', requires_grad=True))]

        通过上述选择的修剪技术生成的修剪掩码被保存为名为weight_mask的模块缓冲区(即在初始参数名后附加“_mask”)。

print(list(module.named_buffers()))

[('weight_mask', tensor([[[[1., 1., 1.],
          [1., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [1., 0., 1.],
          [0., 0., 1.]]],


        [[[0., 1., 1.],
          [1., 1., 1.],
          [1., 0., 0.]]],


        [[[1., 1., 0.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [0., 1., 0.],
          [1., 1., 0.]]]], device='cuda:0'))]

        为了使转发传递不需要修改就能工作,weight属性需要存在。在torch.n .utils.prune中实现的修剪技术计算经过修剪的权重(通过将掩码与原始参数结合起来),并将它们存储在属性权重中。注意,这不再是模块的参数,它现在只是一个属性。

print(module.weight)

OUT:

tensor([[[[-0.0629,  0.3243, -0.0095],
          [ 0.2714,  0.0000,  0.1099],
          [-0.2398, -0.0376,  0.0314]]],


        [[[ 0.0646, -0.1846, -0.0000],
          [-0.1228,  0.0000, -0.1311],
          [ 0.0000,  0.0000, -0.0924]]],


        [[[ 0.0000, -0.1994,  0.0332],
          [ 0.0586, -0.2853,  0.2788],
          [-0.1882,  0.0000, -0.0000]]],


        [[[ 0.1629,  0.2119,  0.0000],
          [-0.0000,  0.0000,  0.0096],
          [ 0.3129, -0.1592, -0.0972]]],


        [[[-0.1942,  0.1310,  0.2086],
          [ 0.1912,  0.0000,  0.3324],
          [ 0.1050,  0.2687,  0.3313]]],


        [[[-0.3087,  0.0690, -0.0000],
          [ 0.0000, -0.2180,  0.0000],
          [ 0.3015,  0.3015, -0.0000]]]], device='cuda:0',
       grad_fn=)

        最后,使用PyTorch的forward_pre_hooks在每次向前传递之前应用修剪。具体来说,当模块被修剪时,就像我们在这里所做的那样,它将为被修剪的与它相关的每个参数获取一个forward_pre_hook。在本例中,因为到目前为止我们只修剪了名为weight的原始参数,所以只会出现一个钩子。

print(module._forward_pre_hooks)

OUT:

OrderedDict([(0, )])

        为了完整起见,我们现在也可以修剪偏差,以查看模块的参数、缓冲区、钩子和属性是如何变化的。只是为了尝试另一种修剪技术,在这里我们按L1范数修剪偏差中最小的3个条目,正如在l1_unstructured修剪函数中实现的那样。

prune.l1_unstructured(module, name="bias", amount=3)

OUT:

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))

        现在,我们希望命名参数同时包含weight_trans(以前的)和bias_trans。缓冲区将包括weight_mask和bias_mask。两个张量的裁剪版本将作为模块属性存在,模块现在将有两个forward_pre_hooks。

print(list(module.named_parameters()))

OUT:

[('weight_orig', Parameter containing:
tensor([[[[-0.0629,  0.3243, -0.0095],
          [ 0.2714,  0.0487,  0.1099],
          [-0.2398, -0.0376,  0.0314]]],


        [[[ 0.0646, -0.1846, -0.0758],
          [-0.1228,  0.3297, -0.1311],
          [ 0.0151,  0.0279, -0.0924]]],


        [[[ 0.0007, -0.1994,  0.0332],
          [ 0.0586, -0.2853,  0.2788],
          [-0.1882,  0.0664, -0.2495]]],


        [[[ 0.1629,  0.2119,  0.1293],
          [-0.0994,  0.0930,  0.0096],
          [ 0.3129, -0.1592, -0.0972]]],


        [[[-0.1942,  0.1310,  0.2086],
          [ 0.1912,  0.0933,  0.3324],
          [ 0.1050,  0.2687,  0.3313]]],


        [[[-0.3087,  0.0690, -0.3148],
          [ 0.0454, -0.2180,  0.0274],
          [ 0.3015,  0.3015, -0.1018]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.1067, -0.1336, -0.1805, -0.1301, -0.0989,  0.2122], device='cuda:0',
       requires_grad=True))]

print(list(module.named_buffers()))

OUT:

[('weight_mask', tensor([[[[1., 1., 1.],
          [1., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [1., 0., 1.],
          [0., 0., 1.]]],


        [[[0., 1., 1.],
          [1., 1., 1.],
          [1., 0., 0.]]],


        [[[1., 1., 0.],
          [0., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [0., 1., 0.],
          [1., 1., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 1., 1., 0., 0., 1.], device='cuda:0'))]

print(module.bias)

OUT:

tensor([ 0.0000, -0.1336, -0.1805, -0.0000, -0.0000,  0.2122], device='cuda:0',
       grad_fn=)

print(module._forward_pre_hooks)

OUT:

OrderedDict([(0, ), (1, )])

迭代剪枝

        一个模块中的同一个参数可以被多次修剪,各种修剪调用的效果等于一系列应用的各种掩码的组合。新掩码与旧掩码的组合由PruningContainer的compute_mask方法处理。
        比如说,我们现在想要进一步削减模块。权重,这一次使用沿张量第0轴(第0轴对应卷积层的输出通道,对于conv1具有6维数)的结构化修剪,基于通道L2范数。这可以使用ln_structured函数实现,其中n=2, dim=0。

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)

OUT:

tensor([[[[-0.0629,  0.3243, -0.0095],
          [ 0.2714,  0.0000,  0.1099],
          [-0.2398, -0.0376,  0.0314]]],


        [[[ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000]]],


        [[[ 0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000]]],


        [[[-0.1942,  0.1310,  0.2086],
          [ 0.1912,  0.0000,  0.3324],
          [ 0.1050,  0.2687,  0.3313]]],


        [[[-0.3087,  0.0690, -0.0000],
          [ 0.0000, -0.2180,  0.0000],
          [ 0.3015,  0.3015, -0.0000]]]], device='cuda:0',
       grad_fn=)

        对应的钩子现在的类型是torch.nn.utils.prune。PruningContainer,它将存储应用于weight参数的修剪历史。

for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container

OUT:

[, ]

序列化修剪过的模型

        所有相关的张量,包括掩码缓冲区和用于计算修剪张量的原始参数都存储在模型的state_dict中,因此,如果需要,可以很容易地序列化和保存。

print(model.state_dict().keys())

OUT:

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

修剪再参量化

        要使修剪永久化,需要删除weight_trans和weight_mask方面的重新参数化,并删除forward_pre_hook,我们可以使用torch.n .utils.prune中的remove功能。注意,这不会撤销修剪,就好像它从未发生过一样。相反,它只是通过在裁剪版中为模型参数重新分配参数权重,使其永久存在。

在删除重新参数化之前:

print(list(module.named_parameters()))

OUT:

[('weight_orig', Parameter containing:
tensor([[[[-0.0629,  0.3243, -0.0095],
          [ 0.2714,  0.0487,  0.1099],
          [-0.2398, -0.0376,  0.0314]]],


        [[[ 0.0646, -0.1846, -0.0758],
          [-0.1228,  0.3297, -0.1311],
          [ 0.0151,  0.0279, -0.0924]]],


        [[[ 0.0007, -0.1994,  0.0332],
          [ 0.0586, -0.2853,  0.2788],
          [-0.1882,  0.0664, -0.2495]]],


        [[[ 0.1629,  0.2119,  0.1293],
          [-0.0994,  0.0930,  0.0096],
          [ 0.3129, -0.1592, -0.0972]]],


        [[[-0.1942,  0.1310,  0.2086],
          [ 0.1912,  0.0933,  0.3324],
          [ 0.1050,  0.2687,  0.3313]]],


        [[[-0.3087,  0.0690, -0.3148],
          [ 0.0454, -0.2180,  0.0274],
          [ 0.3015,  0.3015, -0.1018]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.1067, -0.1336, -0.1805, -0.1301, -0.0989,  0.2122], device='cuda:0',
       requires_grad=True))]

print(list(module.named_buffers()))

OUT:

[('weight_mask', tensor([[[[1., 1., 1.],
          [1., 0., 1.],
          [1., 1., 1.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[1., 1., 1.],
          [1., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 0.],
          [0., 1., 0.],
          [1., 1., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 1., 1., 0., 0., 1.], device='cuda:0'))]

print(module.weight)

OUT:

tensor([[[[-0.0629,  0.3243, -0.0095],
          [ 0.2714,  0.0000,  0.1099],
          [-0.2398, -0.0376,  0.0314]]],


        [[[ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000]]],


        [[[ 0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000]]],


        [[[-0.1942,  0.1310,  0.2086],
          [ 0.1912,  0.0000,  0.3324],
          [ 0.1050,  0.2687,  0.3313]]],


        [[[-0.3087,  0.0690, -0.0000],
          [ 0.0000, -0.2180,  0.0000],
          [ 0.3015,  0.3015, -0.0000]]]], device='cuda:0',
       grad_fn=)

去除重新参数化后:

prune.remove(module, 'weight')
print(list(module.named_parameters()))

OUT:

[('bias_orig', Parameter containing:
tensor([ 0.1067, -0.1336, -0.1805, -0.1301, -0.0989,  0.2122], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[-0.0629,  0.3243, -0.0095],
          [ 0.2714,  0.0000,  0.1099],
          [-0.2398, -0.0376,  0.0314]]],


        [[[ 0.0000, -0.0000, -0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000]]],


        [[[ 0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000]]],


        [[[ 0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000]]],


        [[[-0.1942,  0.1310,  0.2086],
          [ 0.1912,  0.0000,  0.3324],
          [ 0.1050,  0.2687,  0.3313]]],


        [[[-0.3087,  0.0690, -0.0000],
          [ 0.0000, -0.2180,  0.0000],
          [ 0.3015,  0.3015, -0.0000]]]], device='cuda:0', requires_grad=True))]

print(list(module.named_buffers()))

OUT:

[('bias_mask', tensor([0., 1., 1., 0., 0., 1.], device='cuda:0'))]

修剪模型中的多个参数

        通过指定所需的修剪技术和参数,我们可以很容易地修剪网络中的多个张量,可能是根据它们的类型,就像我们在这个例子中看到的那样。

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

OUT:

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])

全局剪枝

        到目前为止,我们只研究了通常被称为局部修剪的方法,即逐个修剪模型中的张量,方法是将每个张量的统计数据(权重、幅值、激活、梯度等)专门与该张量中的其他张量进行比较。然而,一种常见且可能更强大的技术是一次性删除模型,通过删除(例如)整个模型中最低的20%的连接,而不是删除每一层中最低的20%的连接。这可能会导致每层不同的修剪百分比。让我们看看如何使用global_unstructured from torch.nn.utils.prune实现这一点。

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

        现在我们可以检查每个修剪参数诱导的稀疏性,每一层都不等于20%。然而,全局稀疏度将是(大约)20%。

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

OUT:

Sparsity in conv1.weight: 1.85%
Sparsity in conv2.weight: 6.83%
Sparsity in fc1.weight: 22.11%
Sparsity in fc2.weight: 11.88%
Sparsity in fc3.weight: 11.55%
Global sparsity: 20.00%

使用自定义修剪功能扩展torch.nn.utils.prune

        要实现您自己的修剪函数,您可以通过继承BasePruningMethod基类来扩展nn.utils.prune模块,与所有其他修剪方法相同。基类为你实现了以下方法:__call__, apply_mask, apply, prune和remove。除了一些特殊情况外,您不需要为新的修剪技术重新实现这些方法。然而,你必须实现__init__(构造函数)和compute_mask(关于如何根据修剪技术的逻辑为给定张量计算掩码的指令)。此外,您还必须指定该技术实现的修剪类型(支持的选项有全局的、结构化的和非结构化的)。这是确定在迭代应用修剪的情况下如何组合遮罩所需要的。换句话说,当对预修剪的参数进行修剪时,当前的修剪技术期望作用于参数的未修剪部分。指定PRUNING_TYPE将使PruningContainer(它处理修剪掩码的迭代应用程序)能够正确识别要修剪的参数片。

        让我们假设,例如,你想实现一种修剪技术,修剪张量中的每一个其他项(或者如果张量的剩余未修剪部分之前已经被修剪过)。这将是PRUNING_TYPE='unstructured',因为它作用于一个层中的单个连接,而不是整个单元/通道('结构化'),或跨不同的参数('全局')。

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

        现在,把这个应用到nn中的一个参数上。模块中,您还应该提供一个简单的函数来实例化该方法并应用它。

def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module)
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

让我们试试吧!

model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)

OUT:

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])

翻译完毕~~~~~~~

你可能感兴趣的:(剪枝,深度学习,python)