目前大部分最先进的(SOTA)深度学习技术虽然效果好,但由于其模型参数量和计算量过高,难以用于实际部署。而众所周知,生物神经网络使用高效的稀疏连接(生物大脑神经网络balabala啥的都是稀疏连接的),考虑到这一点,为了减少内存、容量和硬件消耗,同时又不牺牲模型预测的精度,在设备上部署轻量级模型,并通过私有的设备上计算以保证隐私,通过减少参数数量来压缩模型的最佳技术非常重要。
稀疏神经网络在预测精度方面可以达到密集神经网络的水平,但由于模型参数量小,理论上来讲推理速度也会快很多。而模型剪枝是一种将密集神经网络训练成稀疏神经网络的方法。
本文将通过学习官方示例教程,介绍如何通过一个简单的实例教程来进行模型剪枝,实践深度学习模型压缩加速。
相关链接
深度学习模型压缩与加速技术(一):参数剪枝
PyTorch模型剪枝实例教程一、非结构化剪枝
PyTorch模型剪枝实例教程二、结构化剪枝
PyTorch模型剪枝实例教程三、多参数与全局剪枝
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
'''搭建类LeNet网络'''
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 单通道图像输入,5×5核尺寸
self.conv1 = nn.Conv2d(1, 3, 5)
self.conv2 = nn.Conv2d(3, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
module = model.conv1
print(list(module.named_parameters())) # 6×5×5的weight + 6×1的bias 的参数量
输出:
[('weight', Parameter containing:
tensor([[[[ 0.1473, 0.1251, 0.0492, -0.1375, -0.0781],
[ 0.0446, -0.1328, 0.0227, 0.0141, -0.1751],
[ 0.0253, 0.0313, 0.0391, 0.1607, -0.0716],
[-0.1125, -0.1641, 0.1691, 0.1583, 0.0449],
[-0.0094, -0.1916, 0.1701, 0.0704, 0.0407]]],
[[[-0.1945, 0.0709, 0.1071, 0.0038, -0.0686],
[ 0.0187, 0.0710, -0.0955, -0.0778, 0.1927],
[ 0.1643, 0.0791, 0.1235, 0.0241, -0.0021],
[-0.1124, 0.0246, -0.0349, -0.1561, 0.0178],
[-0.1779, 0.1216, 0.1086, -0.1837, 0.1789]]],
[[[-0.0051, -0.1969, -0.0155, 0.1890, 0.1977],
[-0.0654, 0.1219, 0.0849, -0.1937, -0.0933],
[-0.0409, 0.1344, 0.1688, 0.1917, -0.1727],
[ 0.1380, -0.1413, -0.1483, -0.0711, -0.0648],
[-0.1571, 0.0570, 0.1783, -0.0786, 0.1367]]]], requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.0346, -0.1446, 0.0633], requires_grad=True))]
剪枝一个模块,需要三步:
这里,我们根据通道的L2范数,沿着张量的第0轴(第0轴对应卷积层的输出通道,conv1的维数为3×5×5)对weight参数进行结构化剪枝,使用ln_structured
()方法。剪枝比例为33%,dim为0,基于L2范数(n=2)
prune.ln_structured(module, name="weight", amount=0.33, n=2, dim=0)
print(module.weight)
输出:
tensor([[[[ 0.1327, 0.0812, -0.0225, -0.0809, 0.1461],
[ 0.1335, -0.1709, 0.0575, -0.1608, -0.0677],
[-0.0397, -0.0982, 0.0654, -0.1030, -0.1656],
[-0.0570, 0.1940, 0.0085, 0.1896, 0.1979],
[-0.0673, 0.0910, -0.0177, -0.1748, 0.1667]]],
[[[-0.0000, 0.0000, 0.0000, 0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, -0.0000, 0.0000],
[ 0.0000, -0.0000, 0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, -0.0000, -0.0000, -0.0000],
[-0.0000, -0.0000, -0.0000, -0.0000, 0.0000]]],
[[[ 0.1668, 0.1742, -0.1581, -0.1208, 0.0745],
[ 0.0459, -0.0275, -0.1190, -0.1631, -0.1956],
[-0.0480, -0.1716, -0.0168, 0.0089, 0.0876],
[-0.0129, -0.1616, -0.1164, -0.1869, -0.1782],
[ 0.0411, -0.0278, -0.1266, 0.1329, -0.1240]]]],
grad_fn=)
所有相关的张量,包括mask缓冲区和用于计算剪枝张量的原始参数都存储在模型的state_dict中,因此,如果需要,可以很容易地序列化和保存。
print(model.state_dict().keys())
输出:
odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
要使修剪永久存在,可以删除weight_trans(orig)和weight_mask重新参数化,并删除forward_pre_hook,可以使用torch.nn.util.prune中的remove
函数。注意,这并没有取消修剪,就像它从未发生过一样。它只是简单地使它永久存在,相反,在它的剪枝版本中,通过将参数的权重重新分配给模型参数。
print(list(module.named_parameters()))
print(list(module.named_buffers()))
print(model.state_dict().keys())
输出:
[('bias', Parameter containing:
tensor([0.1673, 0.0794, 0.0110], requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 1.2122e-01, -9.2122e-02, 6.4242e-02, 2.0936e-02, 7.7185e-03],
[ 1.6201e-01, -1.2338e-01, -1.2014e-01, 3.0895e-03, -3.8402e-02],
[ 7.5407e-03, 1.9274e-01, -3.0035e-02, 1.9638e-02, -5.5985e-03],
[-1.2915e-01, -7.7561e-02, 6.8224e-02, -1.8743e-01, -1.6051e-01],
[ 1.4066e-01, 1.1038e-01, -1.8010e-01, 9.4039e-02, -1.2981e-01]]],
[[[-1.3836e-01, -1.8937e-01, 3.2540e-02, -6.2541e-02, 1.6695e-01],
[ 1.3803e-01, 1.0196e-01, 8.2551e-02, -1.2815e-06, -1.4024e-02],
[-3.7121e-02, -1.8625e-01, 4.1115e-02, -1.5329e-01, 3.8362e-02],
[-5.7373e-02, 9.3459e-02, 5.9365e-02, -9.4975e-02, 1.7842e-01],
[ 2.2319e-02, -5.2064e-02, -1.9440e-01, -1.7895e-03, 8.3709e-02]]],
[[[ 1.4024e-01, 6.4016e-02, 1.6549e-01, 9.6163e-02, 1.8803e-01],
[-5.8840e-02, -1.8487e-01, 1.8037e-01, 7.3717e-02, 1.9991e-01],
[ 7.9629e-02, -1.1025e-01, 1.2504e-01, 4.6581e-02, 2.2388e-04],
[-3.6367e-02, 9.8296e-02, 6.5209e-02, 1.7801e-01, 1.3420e-01],
[ 1.4725e-01, -1.9269e-01, 1.9282e-02, -1.3924e-01, -6.2607e-02]]]],
requires_grad=True))]
[('weight_mask', tensor([[[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]]]))]
odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
使用remove
函数后
prune.remove(module, 'weight')
print(list(module.named_parameters()))
print(list(module.named_buffers()))
print(model.state_dict().keys())
输出:
[('bias', Parameter containing:
tensor([ 0.1144, -0.1641, 0.0962], requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0142, 0.1698, -0.0730, -0.0358, -0.1309],
[ 0.1520, 0.1900, -0.0843, 0.0950, 0.1674],
[-0.1724, 0.1453, -0.1764, 0.0345, -0.1767],
[ 0.0727, 0.1170, 0.1585, -0.0713, -0.0158],
[ 0.1485, -0.0270, -0.0164, 0.0889, 0.1170]]],
[[[ 0.0000, -0.0000, -0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, -0.0000, 0.0000, -0.0000],
[ 0.0000, 0.0000, -0.0000, -0.0000, 0.0000],
[ 0.0000, 0.0000, -0.0000, -0.0000, -0.0000],
[-0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
[[[-0.1681, 0.1801, -0.0567, -0.0366, 0.0085],
[ 0.0495, 0.0320, -0.0127, -0.1761, -0.0948],
[ 0.1340, 0.1103, 0.1332, -0.1911, -0.1225],
[ 0.0781, -0.0920, -0.1759, 0.0977, 0.0030],
[-0.0436, -0.1694, -0.0094, -0.0553, -0.0591]]]], requires_grad=True))]
[]
odict_keys(['conv1.bias', 'conv1.weight', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
可以发现,缓冲区中保存的mask参数没了
本示例首先搭建了一个类LeNet网络模型,为了进行结构化剪枝,我们选取了LeNet的conv1模块,该模块参数包含为3×5×5的weight卷积核参数和3×1的bias参数,通过示例,我们利用torch.nn.prune中的ln_structured剪枝方法,实现了对weight的3个通道中其中一个通道进行了L2 norm结构化剪枝。
本文用到的核心函数方法:
name
的参数将保持永久修剪,而名为name+_trans
(orig)的参数将从参数列表中删除。类似地,名为name+_mask
的缓冲区将从缓冲区中删除。参考:
Torch官方剪枝教程