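These are my personal study notes on the new model pruning and re-parameterization course released by 手写AI, kept for my own reference.
This lesson is a hands-on walkthrough of pruning a VGG model.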
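The course outline is shown in the mind map below.
This pruning exercise reproduces the paper below, Learning Efficient Convolutional Networks through Network Slimming (Liu et al., ICCV 2017). The core operation is pruning based on the scaling factors $\gamma$ of the BN layers.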
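Let's start with a question: the convolutional layers of a deep learning model produce a very large number of features (channels). Could some of these features, and the connections attached to them, be worthless? And how can we tell whether a feature and its connections are valuable?
The answer first: apply L1 regularization to the scaling factors of the Batch Normalization layers. This is the core idea of the paper above; see the paper for more details.
Advantages:
Below is the loss function proposed in the paper for sparsity training of the BN-layer $\gamma$ parameters: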
$$L = \sum_{(x,y)} l\Big(f(x, W), y\Big) + \lambda \sum_{\gamma \in \Gamma} g(\gamma)$$
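Here $(x, y)$ is a training input and its target, $W$ the trainable weights, $l(\cdot)$ the ordinary classification loss, $\Gamma$ the set of all BN scaling factors, and $g(\gamma) = |\gamma|$ the L1 penalty; $\lambda$ balances the two terms. In other words, the loss adds L1 regularization of the BN $\gamma$ parameters on top of the classification loss. By tuning the hyperparameter $\lambda$, training drives many $\gamma$ values toward 0 (sparsity training). The channels with near-zero $\gamma$ contribute little and can then be removed, which reduces model complexity and can also help generalization.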
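A minimal sketch of how the $\lambda\sum|\gamma|$ term is commonly applied in a PyTorch training step. Because $g(\gamma)=|\gamma|$ has subgradient $\operatorname{sign}(\gamma)$, one can add $\lambda\,\operatorname{sign}(\gamma)$ to every BN weight's gradient after `backward()` instead of putting the penalty into the loss. The helper `add_bn_l1_subgradient` and the toy network are hypothetical, and the previous lesson's `train.py` is not shown here, so this is not its actual code. The last lines check that both formulations give the same gradient.

```python
import torch
import torch.nn as nn


def add_bn_l1_subgradient(model: nn.Module, s: float) -> None:
    """Hypothetical helper: add s * sign(gamma) to each BatchNorm2d weight gradient.

    s * sign(gamma) is the subgradient of s * sum(|gamma|), the L1 sparsity term.
    """
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d) and m.weight.grad is not None:
            m.weight.grad.add_(s * torch.sign(m.weight.detach()))


torch.manual_seed(0)
# Hypothetical toy network: Conv -> BN -> ReLU -> global pool -> Linear
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
criterion = nn.CrossEntropyLoss()
x = torch.randn(4, 3, 8, 8)
y = torch.randint(0, 10, (4,))
s = 1e-4  # plays the role of lambda

# Variant A: ordinary loss, then add the L1 subgradient to the BN gradients
model.zero_grad()
criterion(model(x), y).backward()
add_bn_l1_subgradient(model, s)
grad_a = model[1].weight.grad.clone()

# Variant B: put lambda * sum(|gamma|) directly into the loss
model.zero_grad()
bn_penalty = sum(m.weight.abs().sum() for m in model.modules() if isinstance(m, nn.BatchNorm2d))
loss = criterion(model(x), y) + s * bn_penalty
loss.backward()
grad_b = model[1].weight.grad.clone()

print(torch.allclose(grad_a, grad_b))
```

Variant A avoids building the extra penalty graph and only touches the BN gradients, which is why it is a popular way to implement this loss.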
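The overall workflow is shown in the figure below. In short:
train the initial model to get a baseline (benchmark) => sparsity training => pruning => fine-tuning => final model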
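When pruning, we only need to handle the layers that have parameters: Conv2d, BatchNorm2d and Linear. Pool2d layers only downsample and have no learnable parameters, so they need no processing. Some notes on the masks: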
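- cfg and cfg_mask: cfg lists the number of channels kept in each BN layer, with 'M' marking a max-pooling layer; cfg_mask holds one 0/1 mask per BN layer (1 = keep)
- gt(): the mask is obtained with the gt() (greater than) method; channels whose $|\gamma|$ is not greater than the threshold get 0
- Conv2d: keep the input channels selected by the previous layer's mask and the output channels selected by the current mask
- BatchNorm2d: keep weight, bias, running_mean and running_var of the retained channels
- Linear: keep only the input columns that correspond to the retained channels of the last BN layer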
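To make the gt() step concrete, here is a toy example with made-up $|\gamma|$ values for two BN layers (4 and 6 channels). It uses the same global-threshold rule as the pruning code later in this post: sort all values, take the one at index `int(total * percent)` as the threshold, and keep only channels strictly greater than it. The numbers are hypothetical.

```python
import torch

# Made-up |gamma| values of two BN layers (4 and 6 channels)
gammas = [
    torch.tensor([0.9, 0.01, 0.5, 0.02]),
    torch.tensor([0.03, 0.7, 0.04, 0.8, 0.6, 0.05]),
]
percent = 0.5

# One global threshold over all BN channels of the network
all_gammas = torch.cat(gammas)
sorted_gammas, _ = torch.sort(all_gammas)
thre = sorted_gammas[int(all_gammas.numel() * percent)]

# gt() is a strict "greater than", so the threshold channel itself is pruned
cfg_mask = [g.gt(thre).float() for g in gammas]
cfg = [int(m.sum()) for m in cfg_mask]

print(float(thre))  # 0.5
print(cfg_mask[0].tolist(), cfg_mask[1].tolist())
print(cfg)          # [1, 3]
```

Because the comparison is strict, 6 of the 10 channels are pruned here rather than exactly 5.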
We first implement a test() function to evaluate the pruned model. Sample code:
import argparse
from utils import get_test_dataloader
import torch
def parse_opt():
    # Prune settings
    parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
    parser.add_argument('--dataset', type=str, default='cifar10', help='training dataset (default: cifar10)')
    parser.add_argument('--test-batch-size', type=int, default=256, metavar='N', help='input batch size for test (default: 256)')
    parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
    parser.add_argument('--depth', type=int, default=11, help='depth of the vgg')
    parser.add_argument('--percent', type=float, default=0.5, help='scale sparse rate (default: 0.5)')
    parser.add_argument('--model', default='', type=str, metavar='PATH', help='path to the model (default: none)')
    parser.add_argument('--save', default='logs/', type=str, metavar='PATH', help='path to save pruned model (default: logs/)')
    args = parser.parse_args()
    return args
def test(model):
    # Uses the global `args` parsed in __main__
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    test_loader = get_test_dataloader(batch_size=args.test_batch_size, **kwargs)
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            pred = output.max(1, keepdim=True)[1]  # index of the highest score
            correct += pred.eq(target.view_as(pred)).cpu().sum().item()
    accuracy = 100. * correct / len(test_loader.dataset)
    # Report correct / number of test images (not the number of batches)
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), accuracy
    ))
    return accuracy / 100
if __name__ == "__main__":
    args = parse_opt()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
To obtain the weights of a sparsity-trained model, we use train.py from the previous lesson. Run the following command in a terminal:
python .\train.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 11 --epochs 10
Next, we load the sparsity-trained model, because we need statistics of its BN parameters after sparsity training. Sample code:
import os
import argparse
from models.vgg import VGG
from utils import get_test_dataloader
import torch
def parse_opt():
    ...  # same as above
def test(model):
    ...  # same as above
if __name__ == "__main__":
    args = parse_opt()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if not os.path.exists(args.save):
        os.makedirs(args.save)
    model = VGG(depth=args.depth)
    if args.cuda:
        model.cuda()
    if args.model:
        if os.path.isfile(args.model):
            print("=> loading checkpoint '{}'".format(args.model))
            # map_location lets a GPU-saved checkpoint load on a CPU-only machine
            checkpoint = torch.load(args.model, map_location='cuda' if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}".format(
                args.model, checkpoint['epoch'], best_prec1
            ))
        else:
            print("=> no checkpoint found at '{}'".format(args.model))
    print(model)
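Now we prune the BN layers and save the configuration cfg and the corresponding cfg_mask for later use. The snippets from here on assume `import torch.nn as nn` and `import numpy as np`, as in the complete code at the end. Sample code: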
if __name__ == "__main__":
    ...  # argument parsing and model loading as above
    total = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += m.weight.data.shape[0]
    bn = torch.zeros(total)
    index = 0
    # Collect the gamma parameters of all BN layers, stored in nn.BatchNorm2d.weight.data
    # (beta is stored in nn.BatchNorm2d.bias.data)
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            size = m.weight.data.shape[0]
            bn[index:(index+size)] = m.weight.data.abs().clone()
            index += size
    # Global threshold: the |gamma| value at position `percent` of the sorted list
    y, i = torch.sort(bn)
    thre_index = int(total * args.percent)
    thre = y[thre_index]
    pruned = 0
    cfg = []
    cfg_mask = []
    for k, m in enumerate(model.modules()):
        if isinstance(m, nn.BatchNorm2d):
            weight_copy = m.weight.data.abs().clone()
            # The mask stays on the same device as the BN weights (works with or without CUDA)
            mask = weight_copy.gt(thre).float()
            pruned = pruned + mask.shape[0] - int(torch.sum(mask))
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            cfg.append(int(torch.sum(mask)))
            cfg_mask.append(mask.clone())
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(
                k, mask.shape[0], int(torch.sum(mask))
            ))
        elif isinstance(m, nn.MaxPool2d):
            cfg.append('M')
    pruned_ratio = pruned / total
    print("Pre-process Successful Pruned Ratio: {:.2f}%".format(pruned_ratio * 100.))
    acc = test(model)
    print(cfg)
The printed output:
layer index: 3 total channel: 64 remaining channel: 63
layer index: 7 total channel: 128 remaining channel: 126
layer index: 11 total channel: 256 remaining channel: 227
layer index: 14 total channel: 256 remaining channel: 162
layer index: 18 total channel: 512 remaining channel: 180
layer index: 21 total channel: 512 remaining channel: 194
layer index: 25 total channel: 512 remaining channel: 191
layer index: 28 total channel: 512 remaining channel: 232
Pre-process Sucessful Pruned Ratio: 50.04%
Files already downloaded and verified
Test set: Accuracy: 1757/40 (17.6%)
[63, 'M', 126, 'M', 227, 162, 'M', 180, 194, 'M', 191, 232]
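(Note: the `40` in `1757/40` is the number of test batches, ⌈10000 / 256⌉ = 40, because the original print used `len(test_loader)` instead of `len(test_loader.dataset)`; the 17.6% accuracy itself is computed over the 10,000 test images.) The resulting configuration differs from the original, as the comparison below shows: the early layers hardly change, the later layers are pruned heavily, and accuracy drops sharply at this stage (before fine-tuning).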
# old
[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512]
# new
[63, 'M', 126, 'M', 227, 162, 'M', 180, 194, 'M', 191, 232]
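The per-layer numbers in the log can be checked with a few lines of arithmetic. The two lists are copied from the old/new configurations above. The sketch shows that 1377 of the 2752 BN channels were removed, which is where the 50.04% comes from. It also agrees with the strict gt() comparison: the threshold sits at index `int(2752 * 0.5) = 1376` of the sorted values, and that channel is pruned as well, so 1376 + 1 = 1377 channels go (assuming no ties at the threshold).

```python
old_cfg = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512]
new_cfg = [63, 'M', 126, 'M', 227, 162, 'M', 180, 194, 'M', 191, 232]


def count_channels(cfg):
    """Sum the channel counts in a VGG cfg list, skipping 'M' (max-pool) entries."""
    return sum(v for v in cfg if v != 'M')


total = count_channels(old_cfg)
remaining = count_channels(new_cfg)
pruned = total - remaining
print(total, remaining, pruned)                  # 2752 1375 1377
print('{:.2f}%'.format(100.0 * pruned / total))  # 50.04%
```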
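Once we have cfg and cfg_mask, we can prune the three parameterized layer types: Conv2d, BatchNorm2d and Linear.
First we build a new model from cfg and save some information about it. Sample code: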
if __name__ == "__main__":
    ...  # previous steps
    newmodel = VGG(cfg=cfg)
    if args.cuda:
        newmodel.cuda()
    num_parameters = sum([param.nelement() for param in newmodel.parameters()])
    savepath = os.path.join(args.save, "prune.txt")
    with open(savepath, 'w') as fp:
        fp.write("Configuration: " + str(cfg) + "\n")
        fp.write("Number of parameters: " + str(num_parameters) + "\n")
        fp.write("Test accuracy: " + str(acc))
    layer_id_in_cfg = 0
    start_mask = torch.ones(3)  # the first conv sees the 3 RGB input channels
    end_mask = cfg_mask[layer_id_in_cfg]
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        pass
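Note: start_mask and end_mask are the masks of the input and output channels of each Conv+BN pair. start_mask starts as torch.ones(3) because the first convolution receives the 3 RGB channels of the image; after each BN layer, its end_mask becomes the next start_mask.
BN layer pruning, sample code: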
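To see how cfg drives the shape of the new model, here is a hypothetical cfg-driven VGG builder. It is a sketch, not the course's `models.vgg.VGG`, and it assumes that each integer adds Conv2d(3×3, no bias) + BatchNorm2d + ReLU, that each 'M' adds a 2×2 max-pool, and that the features are pooled to 1×1 before a single Linear classifier. With the pruned cfg from the log, the conv (in, out) pairs follow the mask chain 3 => 63 => 126 => ..., and the classifier input shrinks to 232.

```python
import torch
import torch.nn as nn


def make_vgg(cfg, in_channels=3, num_classes=10):
    """Hypothetical cfg-driven VGG for 32x32 inputs (assumption, not models.vgg.VGG)."""
    layers = []
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers += [nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False),
                       nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            in_channels = v  # the next conv consumes this layer's output channels
    return nn.Sequential(*layers, nn.AdaptiveAvgPool2d(1), nn.Flatten(),
                         nn.Linear(in_channels, num_classes))


def num_params(model):
    return sum(p.numel() for p in model.parameters())


old_cfg = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512]
new_cfg = [63, 'M', 126, 'M', 227, 162, 'M', 180, 194, 'M', 191, 232]
old_model, new_model = make_vgg(old_cfg), make_vgg(new_cfg)

# Conv (in, out) pairs of the pruned model
conv_shapes = [(m.in_channels, m.out_channels)
               for m in new_model.modules() if isinstance(m, nn.Conv2d)]
print(conv_shapes)
print(num_params(old_model), num_params(new_model))
print(new_model(torch.randn(2, 3, 32, 32)).shape)
```

Because this toy builder produces modules in the same order for any cfg, zipping `old_model.modules()` with `new_model.modules()` pairs each layer with its pruned counterpart, which is what the copy loop relies on.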
if __name__ == "__main__":
    ...
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        if isinstance(m0, nn.BatchNorm2d):
            # Indices of the channels kept by the current mask
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            if idx1.size == 1:
                # A single kept channel squeezes to a 0-d array; make it 1-D again
                idx1 = np.resize(idx1, (1,))
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            m1.running_var = m0.running_var[idx1.tolist()].clone()
            # This layer's output mask becomes the next layer's input mask
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # keep the last mask for the final Linear layer
                end_mask = cfg_mask[layer_id_in_cfg]
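A small self-check of the BN copy step, with hypothetical sizes and a made-up mask: in eval mode a smaller BatchNorm2d that receives the kept channels' weight, bias, running_mean and running_var produces exactly the same outputs on those channels. The last lines show why the `idx.size == 1` branch exists: `np.squeeze` turns a single index into a 0-d array, and `np.resize(..., (1,))` restores a 1-D list.

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(0)
bn_old = nn.BatchNorm2d(6)
# Give the old BN non-trivial parameters and running statistics
with torch.no_grad():
    bn_old.weight.uniform_(0.5, 1.5)
    bn_old.bias.uniform_(-0.5, 0.5)
    bn_old.running_mean.uniform_(-1.0, 1.0)
    bn_old.running_var.uniform_(0.5, 2.0)

end_mask = torch.tensor([0., 1., 0., 1., 1., 0.])  # hypothetical mask: keep channels 1, 3, 4
idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
if idx1.size == 1:
    idx1 = np.resize(idx1, (1,))
keep = idx1.tolist()

# Copy the kept channels into a smaller BN layer, as in the pruning loop
bn_new = nn.BatchNorm2d(len(keep))
bn_new.weight.data = bn_old.weight.data[keep].clone()
bn_new.bias.data = bn_old.bias.data[keep].clone()
bn_new.running_mean = bn_old.running_mean[keep].clone()
bn_new.running_var = bn_old.running_var[keep].clone()

# Eval mode normalizes with the running statistics
bn_old.eval()
bn_new.eval()
x = torch.randn(2, 6, 4, 4)
with torch.no_grad():
    same = torch.allclose(bn_old(x)[:, keep], bn_new(x[:, keep]))
print(keep, same)

# One-element case: squeeze gives a 0-d array, resize makes it 1-D again
single = np.squeeze(np.argwhere(np.array([0., 1., 0.])))
print(single.shape, np.resize(single, (1,)).tolist())
```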
Conv layer pruning, sample code:
if __name__ == "__main__":
    ...
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        if isinstance(m0, nn.BatchNorm2d):
            ...  # BN branch as above
        elif isinstance(m0, nn.Conv2d):
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print("In channels: {:d}, Out channels: {:d}".format(idx0.size, idx1.size))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            # Keep the input channels selected by start_mask, then the output channels selected by end_mask
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
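Why slicing the conv weight is safe: a pruned input channel carries only zeros (its BN $\gamma$ and $\beta$ were zeroed, so BN outputs 0 and ReLU keeps it 0), so dropping it from the weight does not change the kept outputs. The sizes and masks below are hypothetical. As an alternative to `np.argwhere` + `np.squeeze`, this sketch uses `torch.nonzero(mask).flatten()`, which always returns a 1-D result and so needs no special case for a single index.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv_old = nn.Conv2d(4, 6, kernel_size=3, padding=1, bias=False)

start_mask = torch.tensor([1., 0., 1., 1.])          # hypothetical: input channel 1 pruned
end_mask = torch.tensor([0., 1., 0., 1., 1., 0.])    # hypothetical: keep output channels 1, 3, 4
idx0 = torch.nonzero(start_mask).flatten().tolist()  # [0, 2, 3]
idx1 = torch.nonzero(end_mask).flatten().tolist()    # [1, 3, 4]

conv_new = nn.Conv2d(len(idx0), len(idx1), kernel_size=3, padding=1, bias=False)
w1 = conv_old.weight.data[:, idx0, :, :].clone()  # select the kept input channels
w1 = w1[idx1, :, :, :].clone()                     # then the kept output channels
conv_new.weight.data = w1

x = torch.randn(2, 4, 8, 8)
x[:, 1] = 0.0  # the pruned input channel is all zeros
with torch.no_grad():
    same = torch.allclose(conv_old(x)[:, idx1], conv_new(x[:, idx0]), atol=1e-5)
print(same)
```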
Linear layer pruning, sample code:
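if __name__ == "__main__":
    ...
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        if isinstance(m0, nn.BatchNorm2d):
            ...  # BN branch as above
        elif isinstance(m0, nn.Conv2d):
            ...  # Conv2d branch as above
        elif isinstance(m0, nn.Linear):
            # Only the input features (columns) are pruned; the class outputs stay unchanged
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            m1.weight.data = m0.weight.data[:, idx0.tolist()].clone()
            m1.bias.data = m0.bias.data.clone()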
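The same check for the classifier, with hypothetical sizes: keeping only the weight columns of the surviving channels leaves the logits unchanged when the pruned features are zero, and the bias is copied whole because the classes are not pruned. This relies on each conv channel mapping to exactly one classifier input, which holds if the feature map is pooled to 1×1 before the Linear layer (an assumption about the course's VGG); with a larger spatial map, each channel would own several columns and the indexing would have to account for that.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc_old = nn.Linear(6, 10)                            # 6 input features, 10 classes
start_mask = torch.tensor([0., 1., 0., 1., 1., 0.])  # hypothetical mask of the last BN layer
idx0 = torch.nonzero(start_mask).flatten().tolist()  # [1, 3, 4]

fc_new = nn.Linear(len(idx0), 10)
fc_new.weight.data = fc_old.weight.data[:, idx0].clone()  # keep only the surviving input columns
fc_new.bias.data = fc_old.bias.data.clone()                # outputs (classes) are not pruned

feats = torch.randn(5, 6)
feats[:, [0, 2, 5]] = 0.0  # pruned channels are all zeros after BN masking + ReLU
with torch.no_grad():
    same = torch.allclose(fc_old(feats), fc_new(feats[:, idx0]), atol=1e-6)
print(same)
```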
Complete sample code:
import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from models.vgg import VGG
from utils import get_test_dataloader
def parse_opt():
    # Prune settings
    parser = argparse.ArgumentParser(description='PyTorch Slimming CIFAR prune')
    # Default must be cifar10: test() below only supports CIFAR-10
    parser.add_argument('--dataset', type=str, default='cifar10', help='training dataset (default: cifar10)')
    parser.add_argument('--test-batch-size', type=int, default=256, metavar='N', help='input batch size for testing (default: 256)')
    parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
    # Must match the depth used for sparsity training (--depth 11 in the command above)
    parser.add_argument('--depth', type=int, default=11, help='depth of the vgg')
    parser.add_argument('--percent', type=float, default=0.5, help='scale sparse rate (default: 0.5)')
    parser.add_argument('--model', default='', type=str, metavar='PATH', help='path to the model (default: none)')
    parser.add_argument('--save', default='logs/', type=str, metavar='PATH', help='path to save pruned model (default: logs/)')
    args = parser.parse_args()
    return args
# Simple test of the model after pre-processing prune (BN scales of pruned channels set to zero)
def test(model):
    # Use one worker and pinned memory when running on the GPU
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    # get_test_dataloader (from utils) provides the CIFAR-10 test set
    if args.dataset == 'cifar10':
        test_loader = get_test_dataloader(batch_size=args.test_batch_size, **kwargs)
    else:
        raise ValueError("No valid dataset is given.")
    # Set the model to evaluation mode
    model.eval()
    # Number of correct predictions
    correct = 0
    # Turn off gradient calculation during inference
    with torch.no_grad():
        # Loop through the test data
        for data, target in test_loader:
            # Move the data and target tensors to the GPU if args.cuda is True
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            # Compute the output of the model on the input data
            output = model(data)
            # Predicted class = index of the highest score
            pred = output.max(1, keepdim=True)[1]
            # Count the correct predictions and add them to the running total
            correct += pred.eq(target.view_as(pred)).cpu().sum().item()
    # Compute the test accuracy and print the result
    accuracy = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Accuracy: {}/{} ({:.1f}%)\n'.format(
        correct, len(test_loader.dataset), accuracy))
    # Return the test accuracy as a fraction in [0, 1]
    return accuracy / 100.
if __name__ == '__main__':
    # Parse command line arguments
    args = parse_opt()
    # Use CUDA only if it is available and not disabled
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    # Create the save directory if it does not exist
    if not os.path.exists(args.save):
        os.makedirs(args.save)
    # Create a VGG model with the specified depth
    model = VGG(depth=args.depth)
    # Move the model to the GPU if args.cuda is True
    if args.cuda:
        model.cuda()
    # If a checkpoint path is given, load the sparsity-trained weights from it
    if args.model:
        if os.path.isfile(args.model):
            print("=> loading checkpoint '{}'".format(args.model))
            checkpoint = torch.load(args.model, map_location='cuda' if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
                  .format(args.model, checkpoint['epoch'], best_prec1))
        else:
            print("=> no checkpoint found at '{}'".format(args.model))
    # Print the model to the console
    print(model)
    # Count the total number of channels over all BatchNorm2d layers
    total = 0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            total += m.weight.data.shape[0]
    # Tensor holding |gamma| of every BN channel
    bn = torch.zeros(total)
    index = 0
    # Loop through the modules again and store |gamma| of each BatchNorm2d layer in bn
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            size = m.weight.data.shape[0]
            bn[index:(index+size)] = m.weight.data.abs().clone()
            index += size
    # Sort bn and take the global pruning threshold
    y, i = torch.sort(bn)
    thre_index = int(total * args.percent)
    thre = y[thre_index]
    # Number of pruned channels, new configuration and per-layer masks
    pruned = 0
    cfg = []
    cfg_mask = []
    # Loop through the modules a third time and
    # zero out the BN channels whose |gamma| does not exceed the threshold
    for k, m in enumerate(model.modules()):
        if isinstance(m, nn.BatchNorm2d):
            # Mask: 1 = keep, 0 = prune (on the same device as the BN weights)
            weight_copy = m.weight.data.abs().clone()
            mask = weight_copy.gt(thre).float()
            pruned = pruned + mask.shape[0] - int(torch.sum(mask))
            # Apply the mask to the weight (gamma) and bias (beta) of the BN layer
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            # Record the new configuration and mask for this layer
            cfg.append(int(torch.sum(mask)))
            cfg_mask.append(mask.clone())
            # Print information about the pruning for this layer
            print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
                  format(k, mask.shape[0], int(torch.sum(mask))))
        elif isinstance(m, nn.MaxPool2d):
            # Record max-pooling layers as 'M' in the configuration list
            cfg.append('M')
    # Ratio of pruned channels to total channels
    pruned_ratio = pruned / total
    print('Pre-processing Successful! Pruned ratio: {:.2f}%'.format(pruned_ratio * 100.))
    # Evaluate the masked model on the test set
    acc = test(model)
    # ============================ Make real prune ============================
    # Print the new configuration to the console
    print(cfg)
    # Build a new (smaller) VGG model from the pruned configuration
    newmodel = VGG(cfg=cfg)
    # Move the new model to the GPU if available
    if args.cuda:
        newmodel.cuda()
    # Number of parameters in the new model
    num_parameters = sum([param.nelement() for param in newmodel.parameters()])
    # Save the configuration, number of parameters and test accuracy to a file
    savepath = os.path.join(args.save, "prune.txt")
    with open(savepath, "w") as fp:
        fp.write("Configuration: \n"+str(cfg)+"\n")
        fp.write("Number of parameters: "+str(num_parameters)+"\n")
        fp.write("Test accuracy: "+str(acc))
    # Masks for the input (start) and output (end) channels of the current Conv+BN pair;
    # the first conv sees the 3 RGB input channels
    layer_id_in_cfg = 0
    start_mask = torch.ones(3)
    end_mask = cfg_mask[layer_id_in_cfg]
    # Walk the original and new models in parallel and
    # copy the weights of the kept channels from the original model to the new one
    for [m0, m1] in zip(model.modules(), newmodel.modules()):
        # ============================ BatchNorm Layers ============================
        # Copy gamma, beta and the running statistics of the kept channels
        if isinstance(m0, nn.BatchNorm2d):
            # Indices of the remaining channels in the current BatchNorm2d layer
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            # A single kept channel squeezes to a 0-d array; make it 1-D again
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            # Weight (gamma) of the kept channels
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            # Bias (beta) of the kept channels
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            # Running mean of the kept channels
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            # Running variance of the kept channels
            m1.running_var = m0.running_var[idx1.tolist()].clone()
            # Update the masks for the next pruned layer
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
                end_mask = cfg_mask[layer_id_in_cfg]
        # ============================ Conv2d Layers ============================
        # Keep the input channels selected by start_mask and the output channels selected by end_mask
        elif isinstance(m0, nn.Conv2d):
            # Indices of the kept input and output channels
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            # Print the number of kept input and output channels
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
            # A single index squeezes to a 0-d array, whose tolist() is a scalar that would
            # drop a dimension when indexing; resize it to shape (1,)
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            # Select the kept input channels, then the kept output channels, from the original weight
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
        # ============================ Linear Layers ============================
        # Keep only the input features that correspond to the channels of the last BN layer
        elif isinstance(m0, nn.Linear):
            # Indices of the remaining channels of the previous layer
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            # Resize the index array if it has only one element
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            # Copy the weight columns of the kept channels; the outputs (classes) are unchanged
            m1.weight.data = m0.weight.data[:, idx0.tolist()].clone()
            m1.bias.data = m0.bias.data.clone()
    # Save the pruned configuration and weights (to be fine-tuned afterwards)
    torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save, 'pruned.pth'))
    print(newmodel)
    model = newmodel
    test(model)
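This lesson completed the pruning of a VGG model. We reproduced the paper's sparsity training of the BN-layer $\gamma$ parameters, derived masks from the result, and used them to prune the Conv2d, BatchNorm and Linear layers. The pruned model has far fewer parameters (71M => 9.6M), and its prediction accuracy even improved (87.4% => 88.4%). When the same method is applied to YOLOv8, accuracy drops slightly but the parameter count is greatly reduced, so it is still practical. Looking forward to the YOLOv8-based pruning lesson.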