对于深度学习来说,比较复杂的模型往往有着不错的识别效果,但是复杂的模型往往对算力要求也比较高,在一些对于实时性要求比较高或者算力比较小的应用场景中,这时复杂的模型往往不能很好达到预期效果,这时候就要进行模型的剪枝,提高模型的运算速度。 剪枝也就是将这个参数置为0,消除这些节点与后面的联系,从而降低运算量,本文主要基于对于模型剪枝的实战展开。
本文参考:深度学习之模型压缩(剪枝、量化)_深度学习模型压缩_CV算法恩仇录的博客-CSDN博客
目录
模型构造
必要的函数解释
module.named_parameters()
module.named_buffers()
model.state_dict().keys()
module._forward_pre_hooks
单层剪枝
连续单层剪枝
全局剪枝
自定义剪枝
本部分主要是先说明下面示例会用到的模型,就是我们大名鼎鼎的LeNet模型,当然其实其他模型也可以,只要是一个有着基本构造的网络都是可以的。
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# 1: 图像的输入通道(1是黑白图像), 6: 输出通道, 3x3: 卷积核的尺寸
self.conv1 = nn.Conv2d(1, 6, 3)
# self.conv1 = nn.Conv2d(2, 3, 3)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 是经历卷积操作后的图片尺寸
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, int(x.nelement() / x.shape[0]))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LeNet().to(device=device)
首先需要简单介绍一下下面可能会频繁出现的一些函数,如果不解释,可能就会看了很迷,我自己去搜索也没有知找到非常直观的解释,所以就按照我的理解线捋一捋。也可以先跳过去,之后遇到了再回来看。解释的都是基于卷积层,其他层类似。
在上面的的模型构造中已经说明module就是表示模型的第一层卷积层(当然任意一层都是可以的),针对于卷积层,这个函数得到的就是卷积核参数的情况以及一些信息。
这里我一开始比较迷惑的就是为什么Conv2d(1, 6, 3)的卷积层有许多3*3的卷积核。假设卷积层如上所示是Conv2d(1, 6, 3),代表输入1通道,输出6通道,卷积核大小3,那么其中的参数最小元素就是3*3的卷积核,因为要输出6个通道,那么每个输出通道都需要和1个输入通道卷积,得到一个通道输出,所以就有6*1个3*3的卷积核;假如输入是2通道,那么每个输出通道都需要和2个输入通道卷积,每个输入通道都需要一个卷积核,然后得到一个输出,这时候就会有6*2个3*3的卷积核。
这个函数表示掩码缓冲区。因为后面剪枝是针对于卷积核参数,需要标记哪些位置的参数是要被删除的,而这个缓冲区的数字是和卷积核参数一一对应的,要是这个参数被剪了,那么这个位置标记为0,否则是1,最后和参数矩阵相乘,被剪掉的位置参数就变为了0。
这个输出的是当前的状态列表,可能是需要用到的一些参数,在剪枝之前这个里面就是单独的参数,在剪枝之后就变成了参数备份和掩码矩阵,具体的作用也不是很懂。
这个参数是一个列表,里面就记录了堆某个层使用的算法记录,比如L1正则化这种。
下面介绍四种剪枝方式
1.单层剪枝(对于特定的卷积层或某进行剪枝)
2.连续单层剪枝(对多层进行单层剪枝)
3.全局剪枝(对全局进行剪枝)
4.自定义剪枝(自定义剪枝规则)
首先是对于某一个特定的层进行剪枝,利用prune.random_unstructured()函数,里面写入参数剪枝模型的特定层,比如卷积层,修剪对象也就是对权重weight还是偏置bias修剪,还有修剪比例,然后就会按照你的要求剪枝。假设修剪的是权重weight,缓冲区buffer里面放的就是掩码,标记了哪些位置的参数要被剪除,这些位置为0,否则为1。
module = model.conv1
print("---修剪前的状态字典")
print(model.state_dict().keys()) # 打印修剪前的状态字典,发现有weight
print("---修剪前的参数")
print(list(module.named_parameters()))
print("---修剪前的缓冲区")
print(list(module.named_buffers()))
prune.random_unstructured(module, name="weight", amount=0.3) # 对参数修剪
print("*" * 50)
print("---修剪后的状态字典")
print(model.state_dict().keys()) # 打印修剪前的状态字典,发现多出了 orig 和 mask
print("---修剪后的参数")
print(list(module.named_parameters())) # 实际上还没有变,下面会解释
print("---修剪后的缓冲区")
print(list(module.named_buffers())) # 这个就是掩码
print("---修剪算法")
print(module._forward_pre_hooks) # 这里里面存放的每个元素是一个算法
可以从状态列表里面看出修建前后的区别,就是weight变为了weight_orig和weight_mask,mask实际上标记了哪些位置是要剪除的,这个数据就是在缓冲区,所以一开始剪枝前是空而剪枝后有了数据。weight_orig就是做了个备份,还是原来的weight,到时候掩码和weight_orig两个相乘就是剪枝后的结果。可以从修剪后的参数中看出,实际上和修剪前是一样的,那么如何将参数变为修剪后,就要用到remove函数,remove也就类似于确定修剪的按钮,执行了之后就会把缓冲区的mask删掉,并且将要剪除的参数变为0,这个过程不可逆,执行之后,要是没有额外备份,那么参数就会被永久改变。(连在上面代码后面)
prune.remove(module, 'weight')
print("---执行remove后的参数")
print(list(module.named_parameters())) # 此时参数变化
连续单层剪枝其实类似于单层剪枝,唯一的不同就是上面对一个卷积层剪枝,现在我们可以利用一个循环,将所有卷积层和全连接层进行剪枝操作,单个循环内其实还是单层剪枝。
print(dict(model.named_buffers()).keys()) # 打印缓冲区
print(model.state_dict().keys()) # 打印初始模型的所有状态字典
print(dict(model.named_buffers()).keys()) # 打印初始模型的mask buffers张量字典名称,发现此时为空(因为还没剪枝)
for name, module in model.named_modules():
# 对模型中所有的卷积层执行l1_unstructured剪枝操作, 选取20%的参数剪枝
if isinstance(module, torch.nn.Conv2d): # 比较第一个是不是第二个表示的类,这里就是判断是不是卷积层
prune.l1_unstructured(module, name="weight", amount=0.2)
# 对模型中所有全连接层执行ln_structured剪枝操作, 选取40%的参数剪枝
elif isinstance(module, torch.nn.Linear):
prune.ln_structured(module, name="weight", amount=0.4, n=2, dim=0)
# 打印多参数模块剪枝后的mask buffers张量字典名称
print(dict(model.named_buffers()).keys()) # 打印缓冲区
print(model.state_dict().keys()) # 打印多参数模块剪枝后模型的所有状态字典名称
可以发现缓冲区内多出了每个层weight的mask,状态字典的weight也变成了weight_orig和mask。
上面两种都是对于特定层剪枝,而全局剪枝则是面向整个模型,在整个模型中剪除多少比例的参数,从而缩减模型。(此处代码基本是搬运的,感谢文首提到的大佬)
model = LeNet().to(device=device)
parameters_to_prune = (
(model.conv1, 'weight'),
(model.conv2, 'weight'),
(model.fc1, 'weight'),
(model.fc2, 'weight'),
(model.fc3, 'weight'))
prune.global_unstructured(parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2)
# 统计每个层被剪枝的数量百分比(也就是统计等于0的数字占总数的比例)
print(
"Sparsity in conv1.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv1.weight == 0))
/ float(model.conv1.weight.nelement())
))
print(
"Sparsity in conv2.weight: {:.2f}%".format(
100. * float(torch.sum(model.conv2.weight == 0))
/ float(model.conv2.weight.nelement())
))
print(
"Sparsity in fc1.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc1.weight == 0))
/ float(model.fc1.weight.nelement())
))
print(
"Sparsity in fc2.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc2.weight == 0))
/ float(model.fc2.weight.nelement())
))
print(
"Sparsity in fc3.weight: {:.2f}%".format(
100. * float(torch.sum(model.fc3.weight == 0))
/ float(model.fc3.weight.nelement())
))
print(
"Global sparsity: {:.2f}%".format(
100. * float(torch.sum(model.conv1.weight == 0)
+ torch.sum(model.conv2.weight == 0)
+ torch.sum(model.fc1.weight == 0)
+ torch.sum(model.fc2.weight == 0)
+ torch.sum(model.fc3.weight == 0))
/ float(model.conv1.weight.nelement()
+ model.conv2.weight.nelement()
+ model.fc1.weight.nelement()
+ model.fc2.weight.nelement()
+ model.fc3.weight.nelement())
))
运行之后就可以发现每一层都不同程度被剪除了参数,计算方式就是计算mask层中0所占的比例。
自定义剪枝的自定义主要是体现在剪枝方法上面,比如参数接近于0或者相对很小那么可能贡献很小,那么这时候就可以考虑剪除,对模型也不会造成很大影响。下面的示例采用隔位剪枝的方式,也就是隔一个剪一个,当然这是可以改掉的。(因为参考的是这么写的)
class myself_pruning_method(prune.BasePruningMethod):
PRUNING_TYPE = "unstructured"
# 内部实现compute_mask函数, 完成程序员自己定义的剪枝规则, 本质上就是如何去mask掉权重参数
def compute_mask(self, t, default_mask):
mask = default_mask.clone()
# 此处定义的规则是每隔一个参数就遮掩掉一个, 最终参与剪枝的参数量的50%被mask掉
# 当然可以自己定义
mask.view(-1)[::2] = 0
return mask
# 自定义剪枝方法的函数, 内部直接调用剪枝类的方法apply
def myself_unstructured_pruning(module, name):
myself_pruning_method.apply(module, name)
return module
# 下面开始剪枝
# 实例化模型类
model = LeNet().to(device=device)
start = time.time() # 计时
# 调用自定义剪枝方法的函数, 对model中的第三个全连接层fc3中的偏置bias执行自定义剪枝
myself_unstructured_pruning(model.fc3, name="bias")
# 剪枝成功的最大标志, 就是拥有了bias_mask参数
print(model.fc3.bias_mask)
# 打印一下自定义剪枝的耗时
duration = time.time() - start
print(duration * 1000, 'ms')