模型剪枝是一种用于神经网络压缩的技术,其主要目的是减少模型的计算复杂性和存储需求,同时尽量保持模型的预测能力。这通常通过删除模型中的冗余信息或减少模型的大小来实现。
剪枝技术主要有以下几种:
重要性剪枝:这种方法首先确定模型中每个权重的重要性,例如可以使用梯度或激活值来判断。然后,删除重要性低的权重,并重新训练模型以调整剩余的权重。
全局剪枝:全局剪枝方法通过对整个网络应用某种全局标准(例如,阈值)来删除权重。这种方法通常在预训练的网络上应用,以减少其大小。
结构化剪枝:这种方法涉及删除网络中的特定层或连接。结构化剪枝通常在预训练的网络上应用,并且可以通过迭代地应用不同的剪枝策略来逐步减小网络的大小。
迭代剪枝:这种方法涉及在每次迭代中应用不同的剪枝策略。例如,可以使用重要性剪枝来删除一些权重,然后使用全局剪枝来进一步减小网络的大小。
混合剪枝:这种方法结合了多种剪枝策略,以实现最佳的压缩效果。例如,可以先使用重要性剪枝来删除一些权重,然后使用全局剪枝来进一步减小网络的大小。
需要注意的是,剪枝技术可能会对模型的性能产生影响,因此需要在压缩模型和保持模型性能之间找到一个平衡点。此外,剪枝后的模型可能需要重新训练以调整剩余的权重。
最先进的深度学习
技术依赖于难以部署的过度参数化模型。相反,生物神经网络已知使用有效的稀疏连接。通过减少模型中的参数数量来确定压缩模型的最佳技术非常重要,这样可以在不牺牲准确性的情况下减少内存、电池和硬件消耗。这反过来又允许您在设备上部署轻量级模型,并通过私有设备上计算来保证隐私。在研究前沿,剪枝用于研究过度参数化和欠参数化网络之间学习动态的差异,研究幸运稀疏子网络和初始化(“彩票”)作为破坏性神经架构搜索技术的作用。
在本教程中,您将学习如何使用torch.nn.utils.prune稀疏神经网络,以及如何扩展它以实现您自己的自定义剪枝技术。
要求
“torch>=1.4.0a0+8e8a5e0”
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
创建模型
在本教程中,我们使用LeCun 等人 (1998) 的LeNet架构。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1 input image channel, 6 output channels, 5x5 square conv kernel
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = LeNet().to(device=device)
检查模块
让我们检查一下conv1LeNet 模型中的(未剪枝的)层。目前它将包含两个参数weight和bias,并且没有缓冲区。
module = model.conv1
print(list(module.named_parameters()))
输出:
[('weight', Parameter containing:
tensor([[[[ 0.1529, 0.1660, -0.0469, 0.1837, -0.0438],
[ 0.0404, -0.0974, 0.1175, 0.1763, -0.1467],
[ 0.1738, 0.0374, 0.1478, 0.0271, 0.0964],
[-0.0282, 0.1542, 0.0296, -0.0934, 0.0510],
[-0.0921, -0.0235, -0.0812, 0.1327, -0.1579]]],
[[[-0.0922, -0.0565, -0.1203, 0.0189, -0.1975],
[ 0.1806, -0.1699, 0.1544, 0.0333, -0.0649],
[ 0.1236, 0.0312, 0.1616, 0.0219, -0.0631],
[ 0.0537, -0.0542, 0.0842, 0.1786, 0.1156],
[-0.0874, 0.1155, 0.0358, 0.1016, -0.1219]]],
[[[-0.1980, -0.0773, -0.1534, 0.1641, 0.0576],
[ 0.0828, 0.0633, -0.0035, 0.1565, -0.1421],
[ 0.0126, -0.1365, 0.0617, -0.0689, 0.0613],
[-0.0417, 0.1659, -0.1185, -0.1193, -0.1193],
[ 0.1799, 0.0667, 0.1925, -0.1651, -0.1984]]],
[[[-0.1565, -0.1345, 0.0810, 0.0716, 0.1662],
[-0.1033, -0.1363, 0.1061, -0.0808, 0.1214],
[-0.0475, 0.1144, -0.1554, -0.1009, 0.0610],
[ 0.0423, -0.0510, 0.1192, 0.1360, -0.1450],
[-0.1068, 0.1831, -0.0675, -0.0709, -0.1935]]],
[[[-0.1145, 0.0500, -0.0264, -0.1452, 0.0047],
[-0.1366, -0.1697, -0.1101, -0.1750, -0.1273],
[ 0.1999, 0.0378, 0.0616, -0.1865, -0.1314],
[-0.0666, 0.0313, -0.1760, -0.0862, -0.1197],
[ 0.0006, -0.0744, -0.0139, -0.1355, -0.1373]]],
[[[-0.1167, -0.0685, -0.1579, 0.1677, -0.0397],
[ 0.1721, 0.0623, -0.1694, 0.1384, -0.0550],
[-0.0767, -0.1660, -0.1988, 0.0572, -0.0437],
[ 0.0779, -0.1641, 0.1485, -0.1468, -0.0345],
[ 0.0418, 0.1033, 0.1615, 0.1822, -0.1586]]]], device='cuda:0',
requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497, 0.1822, -0.1468], device='cuda:0',
requires_grad=True))]
print(list(module.named_buffers()))
输出:
[]
剪枝模块
要剪枝模块(在本例中是conv1LeNet 架构的层),首先在可用的剪枝技术中选择一种剪枝技术 torch.nn.utils.prune(或 通过子类化实现 您自己的 技术BasePruningMethod)。然后,指定模块以及要在该模块中删除的参数名称。最后,使用所选剪枝技术所需的适当关键字参数,指定剪枝参数。
weight在此示例中,我们将随机剪枝层中指定参数中 30% 的连接conv1。模块作为第一个参数传递给函数;name 使用其字符串标识符来标识该模块内的参数;并 amount指示要剪枝的连接的百分比(如果它是 0 和 1 之间的浮点数),或者要剪枝的连接的绝对数量(如果它是非负整数)。
prune.random_unstructured(module, name="weight", amount=0.3)
输出:
Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
剪枝通过weight从参数中删除并将其替换为名为 的新参数weight_orig(即附加"_orig"到初始参数name)来进行。weight_orig存储张量的未剪枝版本。没有被剪枝,所以它bias会保持完整。
print(list(module.named_parameters()))
输出:
[('bias', Parameter containing:
tensor([ 0.0503, -0.0860, -0.0219, -0.1497, 0.1822, -0.1468], device='cuda:0',
requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.1529, 0.1660, -0.0469, 0.1837, -0.0438],
[ 0.0404, -0.0974, 0.1175, 0.1763, -0.1467],
[ 0.1738, 0.0374, 0.1478, 0.0271, 0.0964],
[-0.0282, 0.1542, 0.0296, -0.0934, 0.0510],
[-0.0921, -0.0235, -0.0812, 0.1327, -0.1579]]],
[[[-0.0922, -0.0565, -0.1203, 0.0189, -0.1975],
[ 0.1806, -0.1699, 0.1544, 0.0333, -0.0649],
[ 0.1236, 0.0312, 0.1616, 0.0219, -0.0631],
[ 0.0537, -0.0542, 0.0842, 0.1786, 0.1156],
[-0.0874, 0.1155, 0.0358, 0.1016, -0.1219]]],
[[[-0.1980, -0.0773, -0.1534, 0.1641, 0.0576],
[ 0.0828, 0.0633, -0.0035, 0.1565, -0.1421],
[ 0.0126, -0.1365, 0.0617, -0.0689, 0.0613],
[-0.0417, 0.1659, -0.1185, -0.1193, -0.1193],
[ 0.1799, 0.0667, 0.1925, -0.1651, -0.1984]]],
[[[-0.1565, -0.1345, 0.0810, 0.0716, 0.1662],
[-0.1033, -0.1363, 0.1061, -0.0808, 0.1214],
[-0.0475, 0.1144, -0.1554, -0.1009, 0.0610],
[ 0.0423, -0.0510, 0.1192, 0.1360, -0.1450],
[-0.1068, 0.1831, -0.0675, -0.0709, -0.1935]]],
[[[-0.1145, 0.0500, -0.0264, -0.1452, 0.0047],
[-0.1366, -0.1697, -0.1101, -0.1750, -0.1273],
[ 0.1999, 0.0378, 0.0616, -0.1865, -0.1314],
[-0.0666, 0.0313, -0.1760, -0.0862, -0.1197],
[ 0.0006, -0.0744, -0.0139, -0.1355, -0.1373]]],
[[[-0.1167, -0.0685, -0.1579, 0.1677, -0.0397],
[ 0.1721, 0.0623, -0.1694, 0.1384, -0.0550],
[-0.0767, -0.1660, -0.1988, 0.0572, -0.0437],
[ 0.0779, -0.1641, 0.1485, -0.1468, -0.0345],
[ 0.0418