手写AI推出的全新模型剪枝与重参课程。记录下个人学习笔记,仅供自己参考。
本次课程主要讲解YOLOv8剪枝。
课程大纲可看下面的思维导图
YOLOV8剪枝的流程如下:
结论:在VOC2007上使用yolov8s模型进行的实验显示,预训练和约束训练在迭代50个epoch后达到了相同的mAP(:0.5)值,约为0.77。剪枝后,微调阶段需要65个epoch才能达到相同的mAP50。修建后的ONNX模型大小从43M减少到36M。
注意:我们需要将网络结构和网络权重区分开来,YOLOv8的网络结构来自yaml文件,如果我们进行剪枝后保存的权重文件的结构其实是和原始的yaml文件不符合的,需要对yaml文件进行修改满足我们的要求。
步骤如下:
# FILE: ultralytics/yolo/engine/trainer.py
...
def check_amp(model):
# Avoid using mixed precision to affect finetune
return False # <============== modified(修改部分)
device = next(model.parameters()).device # get model device
if device.type in ('cpu', 'mps'):
return False # AMP only used on CUDA devices
def amp_allclose(m, im):
# All close FP32 vs AMP results
...
约束训练是为了筛选哪些channel比较重要,哪些channel没有那么重要,也就是我们上节课所说的稀疏训练
# FILE: ultralytics/yolo/engine/trainer.py
...
# Backward
self.scaler.scale(self.loss).backward()
# <============ added(新增)
l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():
if isinstance(m, nn.BatchNorm2d):
m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))
# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
if ni - last_opt_step >= self.accumulate:
self.optimizer_step()
last_opt_step = ni
...
注意1:在剪枝时,我们选择加载last.pt而非best.pt,因为由于迁移学习,模型的泛化性比较好,在第一个epoch时mAP值最大,但这并不是真实的,我们需要稳定下来的一个模型进行prune
注意2:我们在对Conv层进行剪枝时,我们只考虑1v1(如BottleNeck,C2f and SPPF)、1vm(如Backbone,Detect)的情形,并不考虑mv1的情形。
思考:Constrained Training的必要性?
约束训练可以使得模型更易于剪枝。在约束训练中,模型会学习到一些通道或者权重系数比较不重要的信息,而这些信息在剪枝过程中得到应用,从而达到模型压缩的效果。而如果直接进行剪枝操作,可能会出现一些问题,比如剪枝后的模型精度大幅下降、剪枝不均匀等。因此,在进行剪枝操作之前,通过稀疏训练的方式,可以更好地准确地确定哪些通道或者权重系数可以被剪掉,从而避免上述问题的发生。
for name, m in model.named_modules():
if isinstance(m, torch.nn.BatchNorm2d):
w = m.weight.abs().detach()
b = m.bias.abs().detach()
ws.append(w)
bs.append(b)
print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
factor = 0.8
ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)
Top-Bottom Conv如下图所示:
TopConv剪枝的示例代码如下:
def prune_conv(conv1: Conv, conv2: Conv):
gamma = conv1.bn.weight.data.detach()
beta = conv1.bn.bias.data.detach()
keep_idxs = []
local_threshold = threshold
while len(keep_idxs) < 8:
keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
local_threshold = local_threshold * 0.5
n = len(keep_idxs)
print(n / len(gamma) * 100) # 打印我们保留了多少的channel
# prune
conv1.bn.weight.data = gamma[keep_idxs]
conv1.bn.bias.data = beta[keep_idxs]
conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
conv1.bn.num_features = n
conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
conv1.conv.out_channels = n
if conv1.conv.bias is not None:
conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]
# pattern to prune
# 1. prune all 1 vs 1 TB pattern e.g. bottleneck
for name, m in model.named_modules():
if isinstance(m, Bottleneck):
prune_conv(m.cv1, m.cv2)
注意:由于NVIDIA的硬件加速的原因,我们保留的channels应该大于等于8,我们可以通过设置local_threshold,尽量小点,让更多的channel保留下来。
BottomConv剪枝的示例代码如下:
def prune_conv(conv1: Conv, conv2: Conv):
...
if not isinstance(conv2, list):
conv2 = [conv2]
for item in conv2:
if item is not None:
if isinstance(item, Conv):
conv = item.conv
else:
conv = item
conv.in_channels = n
conv.weight.data = conv.weight.data[:, keep_idxs]
注意:BottomConv存在两种情形,一种是单个Conv,还有一种是Conv列表。TopConv需要考虑conv2d+bn+act,而BottomConv只需要考虑conv2d
Seq剪枝的示例代码如下:
def prune(m1, m2):
if isinstance(m1, C2f):
m1 = m1.cv2
if not isinstance(m2, list):
m2 = [m2]
for i, item in enumerate(m2):
if isinstance(item, C2f) or isinstance(item, SPPF):
m2[i] = item.cv1
prune_conv(m1, m2)
# 2. prune sequential
seq = model.model
for i in range(3, 9):
if i in [6, 4, 9]: continue
prune(seq[i], seq[i+1])
注意:我们不考虑1vm的情形,因此在4,6,9的module我们是不进行剪枝的,后续head进行Concat时是对4,6,9的module进行拼接的。考虑到前几个conv的特征提取的重要性,我们也不剪枝它们。(那感觉没剪几个module呀)
Detect-FPN剪枝的示例代码如下:
# 3. prune FPN related block
detect: Detect = seq[-1]
last_inputs = [seq[15], seq[18], seq[21]]
colasts = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
prune(last_input, [colast, cv2[0], cv3[0]])
prune(cv2[0], cv2[1])
prune(cv2[1], cv2[2])
prune(cv3[0], cv3[1])
prune(cv3[1], cv3[2])
for name, p in yolo.model.named_parameters():
p.requires_grad = True
注意:一定要设置所有参数为需要训练。因为加载后的model会给弄成False,导致报错
完整的示例代码如下:
from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
# Load a model
yolo = YOLO("epoch-50-constrained_weights/last.pt") # build a new model from scratch
model = yolo.model
ws = []
bs = []
for name, m in model.named_modules():
if isinstance(m, torch.nn.BatchNorm2d):
w = m.weight.abs().detach()
b = m.bias.abs().detach()
ws.append(w)
bs.append(b)
print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
# keep
factor = 0.8
ws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)
def prune_conv(conv1: Conv, conv2: Conv):
gamma = conv1.bn.weight.data.detach()
beta = conv1.bn.bias.data.detach()
# if gamma.abs().min() > f or beta.abs().min() > 0.1:
# return
# idxs = torch.argsort(gamma.abs() * coeff + beta.abs(), descending=True)
keep_idxs = []
local_threshold = threshold
while len(keep_idxs) < 8:
keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
local_threshold = local_threshold * 0.5
n = len(keep_idxs)
# n = max(int(len(idxs) * 0.8), p)
print(n / len(gamma) * 100)
# scale = len(idxs) / n
conv1.bn.weight.data = gamma[keep_idxs]
conv1.bn.bias.data = beta[keep_idxs]
conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
conv1.bn.num_features = n
conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
conv1.conv.out_channels = n
if conv1.conv.bias is not None:
conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]
if not isinstance(conv2, list):
conv2 = [conv2]
for item in conv2:
if item is not None:
if isinstance(item, Conv):
conv = item.conv
else:
conv = item
conv.in_channels = n
conv.weight.data = conv.weight.data[:, keep_idxs]
def prune(m1, m2):
if isinstance(m1, C2f): # C2f as a top conv
m1 = m1.cv2
if not isinstance(m2, list): # m2 is just one module
m2 = [m2]
for i, item in enumerate(m2):
if isinstance(item, C2f) or isinstance(item, SPPF):
m2[i] = item.cv1
prune_conv(m1, m2)
for name, m in model.named_modules():
if isinstance(m, Bottleneck):
prune_conv(m.cv1, m.cv2)
seq = model.model
for i in range(3, 9):
if i in [6, 4, 9]: continue
prune(seq[i], seq[i+1])
detect:Detect = seq[-1]
last_inputs = [seq[15], seq[18], seq[21]]
colasts = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
prune(last_input, [colast, cv2[0], cv3[0]])
prune(cv2[0], cv2[1])
prune(cv2[1], cv2[2])
prune(cv3[0], cv3[1])
prune(cv3[1], cv3[2])
# ***step4,一定要设置所有参数为需要训练。因为加载后的model他会给弄成false。导致报错
# pipeline:
# 1. 为模型的BN增加L1约束,lambda用1e-2左右
# 2. 剪枝模型,比如用全局阈值
# 3. finetune,一定要注意,此时需要去掉L1约束。最终final的版本一定是去掉的
for name, p in yolo.model.named_parameters():
p.requires_grad = True
# 1. 不能剪枝的layer,其实可以不用约束
# 2. 对于低于全局阈值的,可以删掉整个module
# 3. keep channels,对于保留的channels,他应该能整除n才是最合适的,否则硬件加速比较差
# n怎么选,一般fp16时,n为8
# int8时,n为16
# cp.async.cg.shared
#
yolo.val()
# yolo.export(format="onnx")
# yolo.train(data="VOC.yaml", epochs=100)
print("done")
关于yolov8剪枝有以下几点值得注意:
Pipeline:
ultralytics/yolo/engine/trainer.py
中注释)yolo.model.named_parameters()
循环,需要设置p.requires_grad
为True
Future work:
n怎么选呢?一般fp16时,n为8;int8时,n为16
本次课程学习了YOLOv8的剪枝,主要是对前面剪枝课程的一个总结和实现吧,大体流程就是稀疏训练后进行剪枝最后微调,看着虽然简单,但实际细节把控还是非常多的,比如说哪些部分好剪,哪些部分不好剪,剪枝的过程中如何通过model获取想要prune的module等等,需要对YOLOv8整体网络结构和对ONNX模型的操作非常熟练,这还只是基础理论,实操部分的坑还没踩呢,在之后好好练习练习吧