PyTorch Model Pruning Techniques

Introduction

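Pruning is one of the most common techniques for compressing deep learning models. Its goal is to make the weight parameters of models such as CNNs, RNNs and Transformers sparse, that is, to make the weights contain a large number of zero elements.
Advantages of a sparse model: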

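  1. Storage: if the weights contain many zeros, they can be stored in a compressed sparse format such as COO or CSR, which keeps only the non-zero values and their indices.
  2. Computation: many modern accelerators such as NPUs include a zero-skipping unit that skips operations on zero elements, so a sparse model can need less computation on such hardware.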
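To make the storage advantage concrete, here is a minimal sketch. The 3×4 tensor is made-up illustrative data. The sketch measures the sparsity of a mostly-zero weight and converts it to PyTorch's COO and CSR layouts with `Tensor.to_sparse()` and `Tensor.to_sparse_csr()`. Note that `torch.nn.utils.prune` itself keeps dense tensors, so memory savings only appear after such a conversion.

```python
import torch

# Made-up 3x4 "weight" in which 9 of the 12 entries are zero
w = torch.tensor([[0.0, 0.5, 0.0, 0.0],
                  [0.0, 0.0, 0.0, -1.2],
                  [0.3, 0.0, 0.0, 0.0]])

# Sparsity = fraction of zero entries
sparsity = (w == 0).float().mean().item()
print(sparsity)                      # 0.75

# COO: one (row, col) index pair per non-zero value
coo = w.to_sparse()
print(coo.indices().tolist())        # [[0, 1, 2], [1, 3, 0]]
print(coo.values())                  # the 3 non-zero values

# CSR: compressed row pointers + column indices + values
csr = w.to_sparse_csr()
print(csr.crow_indices().tolist())   # [0, 1, 2, 3]
print(csr.col_indices().tolist())    # [1, 3, 0]
```

Only 3 values plus their indices are stored instead of 12 dense entries. The gain grows with the size and sparsity of the tensor.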

The main goal of this section is to learn how to apply pruning in PyTorch. Let's code!


Requirements

Tested in the following environment:

  • Ubuntu 20.04
  • PyTorch 1.12

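Implementation

Model definition

For simplicity, we use a LeNet-5-style network (with 3×3 convolution kernels instead of LeNet-5's original 5×5).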

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
# https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

# %%
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16*5*5, 120)  # 16 channels x 5x5 feature map (for a 1x28x28 input)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)  # flatten all dimensions except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = LeNet().to(device=device)

# inspect a module / layer
module = model.conv1
print(list(module.named_parameters()))
print(list(module.named_buffers()))

The code above prints the learnable parameters (weight and bias) of the first Conv2d layer before pruning. The second list is empty because no buffers exist yet.
Output:


[('weight', Parameter containing:
tensor([[[[-0.1544,  0.0351,  0.2471],
          [-0.0788,  0.2216, -0.0925],
          [-0.1486, -0.1366, -0.0963]]],


        [[[ 0.2780, -0.1358,  0.2029],
          [ 0.2228,  0.0061, -0.1716],
          [-0.3228, -0.1036,  0.2223]]],


        [[[-0.2228,  0.0742, -0.1789],
          [-0.1888, -0.3132, -0.1999],
          [ 0.0359, -0.1263,  0.2270]]],


        [[[-0.2067, -0.2954, -0.1952],
          [-0.2652,  0.2705, -0.1056],
          [ 0.1010, -0.1888, -0.0087]]],


        [[[-0.1197, -0.0913, -0.2631],
          [-0.2442,  0.2834, -0.0278],
          [ 0.1842,  0.1579, -0.3101]]],


        [[[-0.2317,  0.1837,  0.1096],
          [-0.0636, -0.1924, -0.3029],
          [ 0.1714,  0.1079,  0.0050]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.2799,  0.2889,  0.1455,  0.0563, -0.0082,  0.1129], device='cuda:0',
       requires_grad=True))]
[]
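The `fc1` layer expects `16*5*5 = 400` input features. That holds only for a 1×28×28 input (MNIST-sized). A 32×32 input would give a 6×6 feature map (576 features) and fail at `fc1`.

The sketch below assumes a 28×28 input and checks this two ways:
- by hand, using the usual output-size formula `(size + 2*padding - kernel) // stride + 1`;
- with real layers on a dummy batch.

`conv_out` is a hypothetical helper written only for this illustration.

```python
import torch
import torch.nn.functional as F
from torch import nn

def conv_out(size, kernel, stride=1, padding=0):
    """Hypothetical helper: output side length of a square conv layer."""
    return (size + 2 * padding - kernel) // stride + 1

side = 28                        # assumed 1x28x28 input
side = conv_out(side, 3) // 2    # conv1 (3x3) then 2x2 max-pool: 28 -> 26 -> 13
side = conv_out(side, 3) // 2    # conv2 (3x3) then 2x2 max-pool: 13 -> 11 -> 5
print(16 * side * side)          # 400, the in_features of fc1

# Cross-check with real layers on a dummy batch of 2 images
x = torch.randn(2, 1, 28, 28)
x = F.max_pool2d(F.relu(nn.Conv2d(1, 6, 3)(x)), 2)
x = F.max_pool2d(F.relu(nn.Conv2d(6, 16, 3)(x)), 2)
print(x.shape)                   # torch.Size([2, 16, 5, 5])
```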

Pruning the weight and bias of a single layer

The code below prunes the weight and bias of the conv1 layer. The weight is pruned with:
prune.random_unstructured(module=module, name='weight', amount=0.3)

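  • random_unstructured: the pruning method. It performs unstructured pruning, zeroing randomly chosen individual entries. The algorithm is simple, but the resulting zeros are scattered rather than forming whole channels or blocks, so it is not very friendly to hardware acceleration.
  • name='weight': the parameter to prune. It can also be 'bias' (or any other parameter name of the module).
  • amount: how much to prune. A float between 0 and 1 is a fraction: for example, 0.3 prunes 30% of the weight entries. An int is an absolute count: for example, 10 sets 10 entries of the weight to 0.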
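The float/int distinction for `amount` is easy to get wrong. The minimal sketch below prunes a fresh `nn.Linear(10, 10)` weight (100 entries) with several values of `amount` and counts the zeros in the resulting `weight_mask` buffer. `n_pruned` is a hypothetical helper used only for this illustration.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

def n_pruned(amount):
    """Hypothetical helper: prune a fresh 10x10 Linear weight, count masked-out entries."""
    m = prune.random_unstructured(nn.Linear(10, 10), name='weight', amount=amount)
    return int((m.weight_mask == 0).sum())

print(n_pruned(0.3))   # 30  (float: 30% of 100 entries)
print(n_pruned(10))    # 10  (int: exactly 10 entries)
print(n_pruned(1))     # 1   (int 1: a single entry)
print(n_pruned(1.0))   # 100 (float 1.0: every entry)
```

In particular, `amount=1` and `amount=1.0` mean very different things.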

# Pruning a Module
# Prune the first Conv layer

prune.random_unstructured(module=module, name='weight', amount=0.3)
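# After pruning, the parameter `weight` is removed and replaced by `weight_orig` (the original, unpruned weight),
# and a buffer `weight_mask` is added; `module.weight` becomes a plain attribute equal to weight_orig * weight_mask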
print(list(module.named_parameters()))
print(list(module.named_buffers()))
print(module.weight)
# The weight shape is unchanged by pruning, but about 30% of its entries are now 0
# The pruning method object is registered as a forward pre-hook
print(module._forward_pre_hooks)
prune.l1_unstructured(module=module, name='bias', amount=3)
print(list(module.named_parameters())) # bias has been replaced by bias_orig
print(list(module.named_buffers()))
print(module.bias)
print(module._forward_pre_hooks)
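The sketch below repeats the same two steps on a fresh `nn.Conv2d(1, 6, 3)` (CPU, fixed seed) to show what the code above does internally. It does three things:
1. checks that the pruned `weight` equals `weight_orig * weight_mask`;
2. shows that `l1_unstructured` with `amount=3` zeroes the three bias entries with the smallest absolute value (the bias values are copied from the printed output above purely for illustration);
3. calls `prune.remove` to make the pruning permanent, which deletes `*_orig`, `*_mask` and the hooks.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(1, 6, 3)

# 1) Random unstructured pruning of the weight: round(0.3 * 54) = 16 entries masked
prune.random_unstructured(conv, name='weight', amount=0.3)
# `weight` is now recomputed from the original parameter and the mask
assert torch.equal(conv.weight, conv.weight_orig * conv.weight_mask)
print(prune.is_pruned(conv))   # True

# 2) L1 unstructured pruning of the bias: zero the 3 entries with smallest |value|
with torch.no_grad():
    conv.bias.copy_(torch.tensor([0.2799, 0.2889, 0.1455, 0.0563, -0.0082, 0.1129]))
prune.l1_unstructured(conv, name='bias', amount=3)
print(conv.bias_mask)          # tensor([1., 1., 1., 0., 0., 0.])

# 3) Make pruning permanent: drop *_orig, *_mask and the forward pre-hooks
prune.remove(conv, 'weight')
prune.remove(conv, 'bias')
print(sorted(n for n, _ in conv.named_parameters()))  # ['bias', 'weight']
print(list(conv.named_buffers()))                      # []
```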

Pruning multiple layers

For example, to prune every Conv2d and Linear layer in the network, simply iterate over its modules:

# Prune multiple layers
# L1-prune 20% of each Conv2d weight and 40% of each Linear weight
new_model = LeNet()

for name, module in new_model.named_modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module=module, name='weight', amount=0.2)
    elif isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())
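To confirm that the loop above reached every layer with the intended ratio, one can measure the fraction of zeros in each `weight_mask` buffer. This is a sketch under one assumption: a stand-in `nn.Sequential` with the same Conv2d/Linear shapes as LeNet replaces the model, so the block is self-contained (it is never run forward). `mask_sparsity` is a hypothetical helper.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

# Stand-in with LeNet's Conv2d/Linear shapes (only used for pruning, not forward passes)
net = nn.Sequential(
    nn.Conv2d(1, 6, 3), nn.Conv2d(6, 16, 3),
    nn.Linear(16 * 5 * 5, 120), nn.Linear(120, 84), nn.Linear(84, 10),
)

for name, module in net.named_modules():
    if isinstance(module, nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    elif isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

def mask_sparsity(mask):
    """Hypothetical helper: fraction of entries the mask sets to zero."""
    return float((mask == 0).sum()) / mask.numel()

# Buffer names are prefixed with the submodule name, e.g. '0.weight_mask'
report = {name: round(mask_sparsity(buf), 3)
          for name, buf in net.named_buffers() if name.endswith('weight_mask')}
print(report)  # about 0.2 for the Conv2d layers, 0.4 for the Linear layers
```

The Conv2d values are only approximately 0.2 because the number of pruned entries is rounded. For example, conv1 has 54 weights and round(0.2 × 54) = 11 of them are pruned.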

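random_unstructured pruning
This function is one of the pruning methods already implemented in PyTorch: unstructured random pruning. Its source code in torch.nn.utils.prune is shown below.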

def random_unstructured(module, name, amount):
    r"""Prunes tensor corresponding to parameter called ``name`` in ``module``
    by removing the specified ``amount`` of (currently unpruned) units
    selected at random.
    Modifies module in place (and also return the modified module) by:

    1) adding a named buffer called ``name+'_mask'`` corresponding to the
       binary mask applied to the parameter ``name`` by the pruning method.
    2) replacing the parameter ``name`` by its pruned version, while the
       original (unpruned) parameter is stored in a new parameter named
       ``name+'_orig'``.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (str): parameter name within ``module`` on which pruning
                will act.
        amount (int or float): quantity of parameters to prune.
            If ``float``, should be between 0.0 and 1.0 and represent the
            fraction of parameters to prune. If ``int``, it represents the
            absolute number of parameters to prune.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input module

    Examples:
        >>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
        >>> torch.sum(m.weight_mask == 0)
        tensor(1)

    """
    RandomUnstructured.apply(module, name, amount)
    return module
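The function body only delegates to `RandomUnstructured.apply`. This is the `apply` classmethod inherited from `prune.BasePruningMethod`. It creates `name_orig`, `name_mask` and the forward pre-hook, and obtains the mask from the subclass's `compute_mask`.

As a sketch of that mechanism, the hypothetical class `MyRandomUnstructured` re-implements random unstructured pruning by overriding only `compute_mask`. It is similar in spirit to PyTorch's own `RandomUnstructured` but not guaranteed identical; for instance, it does not validate `amount`.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

class MyRandomUnstructured(prune.BasePruningMethod):
    """Hypothetical sketch: random unstructured pruning as a custom method."""
    PRUNING_TYPE = 'unstructured'  # the mask acts on individual entries

    def __init__(self, amount):
        self.amount = amount

    def compute_mask(self, t, default_mask):
        n = t.nelement()
        # float -> fraction of entries, int -> absolute count
        k = round(self.amount * n) if isinstance(self.amount, float) else int(self.amount)
        mask = default_mask.clone(memory_format=torch.contiguous_format)
        if k > 0:
            scores = torch.rand_like(t)                        # one random score per entry
            idx = torch.topk(scores.reshape(-1), k=k).indices  # k randomly chosen entries
            mask.view(-1)[idx] = 0                             # zero them in the mask
        return mask

def my_random_unstructured(module, name, amount):
    # apply() registers name_orig / name_mask and the forward pre-hook
    MyRandomUnstructured.apply(module, name, amount)
    return module

torch.manual_seed(0)
m = my_random_unstructured(nn.Linear(10, 10), 'weight', amount=0.3)
print(int((m.weight_mask == 0).sum()))  # 30
```

Because `apply` does the bookkeeping, a custom method only has to decide which entries to mask.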
