Yesterday I covered the ICLR 2017 paper "Pruning Filters for Efficient ConvNets", so you should already have some idea of what model pruning is. Today I will prune a simple network to get a feel for what pruning can do. All the code for this post lives in my GitHub project: I cloned an existing PyTorch model-compression project, and on top of it I will be publishing some of my own test results, benchmarks for classic network-compression setups, and a few interesting experiments. You are welcome to follow it; the GitHub address is given below and again at the end of the post. Finally, a disclaimer: I am still at the beginner stage, so my understanding is certainly shallow and I will make plenty of mistakes. Please point out any errors and discuss them with me.
https://github.com/BBuf/model-compression

The experimental environment:
python 3.6.2
torch == 1.1.0
cuda 10.0
torchvision == 0.3.0
numpy

Printing the network to be pruned (an NIN-style model, nin.Net in the repo) gives:
Net(
  (tnn_bin): Sequential(
    (0): Conv2d(3, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): FP_Conv2d(
      (conv): Conv2d(192, 160, kernel_size=(1, 1), stride=(1, 1))
      (bn): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): FP_Conv2d(
      (conv): Conv2d(160, 96, kernel_size=(1, 1), stride=(1, 1))
      (bn): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (4): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (5): FP_Conv2d(
      (conv): Conv2d(96, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (6): FP_Conv2d(
      (conv): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1))
      (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (7): FP_Conv2d(
      (conv): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1))
      (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (8): AvgPool2d(kernel_size=3, stride=2, padding=1)
    (9): FP_Conv2d(
      (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (10): FP_Conv2d(
      (conv): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1))
      (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (11): Conv2d(192, 10, kernel_size=(1, 1), stride=(1, 1))
    (12): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): AvgPool2d(kernel_size=8, stride=1, padding=0)
  )
)
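From the printout you can see that FP_Conv2d is simply a full-precision Conv-BN-ReLU block. Its actual definition lives in the repo's nin.py and may take extra arguments, but inferred from the printed structure it presumably looks roughly like this sketch:

import torch.nn as nn

class FP_Conv2d(nn.Module):
    # Inferred from the printed structure above: a Conv-BN-ReLU block.
    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0):
        super(FP_Conv2d, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))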
The pruning code is in prune/normal_regular_prune.py. There are many ways to do channel pruning; the method used in this project is to look at the absolute values of the weight of the BN layer that follows each convolution layer, i.e. the BN layer's gamma parameter. The BN layer can be written as:
$$y=\frac{x-\mathrm{mean}(x)}{\sqrt{\mathrm{Var}(x)+eps}}\cdot gamma+beta$$
Correspondingly, beta is the BN layer's bias parameter. When pruning, each BN scaling factor (the scale, i.e. gamma) is treated as the importance of its channel. Then, from the preset pruning ratio percent and the array formed by the weight parameters of all BN layers in the network, we determine the pruning threshold thre_0. With this threshold we can carry out the pre-pruning and pruning steps.
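In PyTorch, gamma and beta are just the weight and bias of an nn.BatchNorm2d module, so the channel importance used here is simply the absolute value of bn.weight. A minimal standalone illustration (not from the repo):

import torch.nn as nn

bn = nn.BatchNorm2d(4)
bn.weight.data.normal_(0, 0.5)          # give gamma some spread for the demo
print(bn.weight.shape, bn.bias.shape)   # gamma and beta, both of length num_features
importance = bn.weight.data.abs()       # per-channel importance in this pruning scheme
print(importance)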
First we determine the global pruning threshold, and from it derive cfg_mask, the per-layer channel masks of the pruned network; cfg_mask determines the structure of the pruned model. Note that this step only decides which channel indices in each layer will be pruned and records them in cfg_mask; it does not actually perform the pruning yet. I have added some comments to the code, so it should not be hard to follow.
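One caveat before the code: the snippets below rely on a few variables (total, layers, base_number, args) that are set up earlier in prune/normal_regular_prune.py and are not shown here. Roughly, they could be computed as in this sketch (the repo's exact code may differ):

# layers: number of Conv2d layers; total: total number of prunable BN channels
layers = sum(1 for m in model.modules() if isinstance(m, nn.Conv2d))
total = 0
i = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        if i < layers - 1:          # the last BN layer is never pruned
            i += 1
            total += m.weight.data.shape[0]
# base_number (channel-count alignment for regular pruning) and percent (prune
# ratio) are assumed to come from the command line, e.g. --base_number, --percent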
# Determine the global pruning threshold
bn = torch.zeros(total)
index = 0
i = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        if i < layers - 1:
            i += 1
            size = m.weight.data.shape[0]
            bn[index:(index+size)] = m.weight.data.abs().clone()
            index += size
# Sort the BN weights by magnitude
y, j = torch.sort(bn)
thre_index = int(total * args.percent)
if thre_index == total:
    thre_index = total - 1
# The global pruning threshold
thre_0 = y[thre_index]

#******************************** Pre-pruning *********************************
pruned = 0
cfg_0 = []
cfg = []
cfg_mask = []
i = 0
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        if i < layers - 1:
            i += 1
            weight_copy = m.weight.data.clone()
            # Channels to keep
            mask = weight_copy.abs().gt(thre_0).float()
            remain_channels = torch.sum(mask)
            # If every channel would be pruned, warn the user to lower the prune
            # ratio and keep at least the single most important channel
            if remain_channels == 0:
                print('\r\n!please turn down the prune_ratio!\r\n')
                remain_channels = 1
                mask[int(torch.argmax(weight_copy))] = 1
            # ****************** Regular pruning ******************
            # Round the number of surviving channels to a multiple of base_number
            v = 0
            n = 1
            if remain_channels % base_number != 0:
                if remain_channels > base_number:
                    while v < remain_channels:
                        n += 1
                        v = base_number * n
                    if remain_channels - (v - base_number) < v - remain_channels:
                        remain_channels = v - base_number
                    else:
                        remain_channels = v
                    if remain_channels > m.weight.data.size()[0]:
                        remain_channels = m.weight.data.size()[0]
                    remain_channels = torch.tensor(remain_channels)
                    # Recompute the per-layer threshold so that exactly
                    # remain_channels channels survive
                    y, j = torch.sort(weight_copy.abs())
                    thre_1 = y[-remain_channels]
                    mask = weight_copy.abs().ge(thre_1).float()
            # Number of channels pruned so far
            pruned = pruned + mask.shape[0] - torch.sum(mask)
            m.weight.data.mul_(mask)
            m.bias.data.mul_(mask)
            cfg_0.append(mask.shape[0])
            cfg.append(int(remain_channels))
            cfg_mask.append(mask.clone())
            print('layer_index: {:d} \t total_channel: {:d} \t remaining_channel: {:d} \t pruned_ratio: {:f}'.
                format(k, mask.shape[0], int(torch.sum(mask)), (mask.shape[0] - torch.sum(mask)) / mask.shape[0]))
pruned_ratio = float(pruned / total)
print('\r\n! Pre-pruning finished!')
print('total_pruned_ratio: ', pruned_ratio)
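The "regular pruning" block above rounds the number of surviving channels to the nearest multiple of base_number, which keeps the pruned layer widths hardware-friendly. To make the rounding concrete, here is a standalone re-implementation of just that step (my own illustration, same logic as above):

def round_to_base(remain_channels, base_number):
    v, n = 0, 1
    if remain_channels % base_number != 0 and remain_channels > base_number:
        # v becomes the smallest multiple of base_number >= remain_channels
        while v < remain_channels:
            n += 1
            v = base_number * n
        # pick whichever neighbouring multiple of base_number is closer
        if remain_channels - (v - base_number) < v - remain_channels:
            remain_channels = v - base_number
        else:
            remain_channels = v
    return remain_channels

print(round_to_base(9, 4))   # 8  -> rounded down to the nearer multiple
print(round_to_base(11, 4))  # 12 -> rounded up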
Next we test the pre-pruned model. There is not much to say here; just look at the comments in the code.
#******************************** Test the pre-pruned model *********************************
def test():
    # Load the test set
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root=args.data, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            # Per-channel (R, G, B) means and standard deviations for normalization
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
        batch_size=64, shuffle=False, num_workers=1)
    model.eval()
    correct = 0
    for data, target in test_loader:
        if not args.cpu:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        output = model(data)
        pred = output.data.max(1, keepdim=True)[1]
        # Count correctly classified samples
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
    # Compute accuracy
    acc = 100. * float(correct) / len(test_loader.dataset)
    print('Accuracy: {:.2f}%\n'.format(acc))
    return

print('************ Pre-pruned model test ************')
if not args.cpu:
    model.cuda()
test()
After pre-pruning we have, for each feature map, the list of channel indices to prune, and we can now perform the actual pruning according to that list. Note that while the pre-pruning stage obtained the channel indices from the BN layers' scale parameters, the pruning stage must remove not only the corresponding channels of the BN layers but also the corresponding channels of the convolution layers preceding them. The complete pruning code is as follows:
#******************************** Pruning *********************************
# Build a new model with the same architecture as the original but fewer channels
newmodel = nin.Net(cfg)
if not args.cpu:
    newmodel.cuda()
layer_id_in_cfg = 0
# Masks of the channels to keep: start_mask for the current layer's inputs,
# end_mask for its outputs
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
i = 0
for [m0, m1] in zip(model.modules(), newmodel.modules()):
    # Both BN layers and Conv layers have to be pruned
    if isinstance(m0, nn.BatchNorm2d):
        if i < layers - 1:
            i += 1
            # np.squeeze removes single-dimensional entries from the shape;
            # np.argwhere(a) returns the indices of the non-zero elements of a
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            # If there is only one index, keep the array one-dimensional so it
            # can still index the BN weights
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            m1.weight.data = m0.weight.data[idx1].clone()
            m1.bias.data = m0.bias.data[idx1].clone()
            m1.running_mean = m0.running_mean[idx1].clone()
            m1.running_var = m0.running_var[idx1].clone()
            layer_id_in_cfg += 1
            # start_mask is the mask of the layer preceding end_mask; it is
            # needed when pruning the input channels of the next Conv2d
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]
        else:
            # BN layers that are not pruned are copied over directly
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()
    elif isinstance(m0, nn.Conv2d):
        if i < layers - 1:
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            # A conv weight tensor has shape [n, c, h, w]; when two conv layers
            # are chained, the next layer's input channels c' equal the current
            # layer's output channels n, so we index dim 1 with start_mask and
            # dim 0 with end_mask
            w = m0.weight.data[:, idx0, :, :].clone()
            m1.weight.data = w[idx1, :, :, :].clone()
            m1.bias.data = m0.bias.data[idx1].clone()
        else:
            # The last conv layer keeps all its output channels, but its input
            # channels must still match the previous pruned layer
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            m1.weight.data = m0.weight.data[:, idx0, :, :].clone()
            m1.bias.data = m0.bias.data.clone()
    elif isinstance(m0, nn.Linear):
        # For a linear layer, keep only the input features that survive in the
        # previous layer
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        m1.weight.data = m0.weight.data[:, idx0].clone()
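Before retraining, it is worth sanity-checking the pruned model. A quick check (my own addition, not part of the original script) is to compare parameter counts and run a dummy CIFAR-10-sized batch through newmodel:

# Compare parameter counts of the original and pruned models
num_old = sum(p.numel() for p in model.parameters())
num_new = sum(p.numel() for p in newmodel.parameters())
print('params: {} -> {} ({:.1f}% kept)'.format(num_old, num_new, 100. * num_new / num_old))

# Make sure a forward pass still runs on the pruned model
x = torch.rand(1, 3, 32, 32)            # a dummy CIFAR-10 sized input
if not args.cpu:
    x = x.cuda()
print(newmodel(x).shape)                # output holds the 10 class scores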
Run python main.py --refine models_save/nin_prune.pth to retrain the pruned model and test it.
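Retraining with --refine is conceptually just standard fine-tuning of the pruned model on CIFAR-10. A minimal sketch of such a loop (this is an illustration, not the repo's actual main.py; train_loader is assumed to be a CIFAR-10 training DataLoader):

import torch.nn.functional as F
import torch.optim as optim

optimizer = optim.SGD(newmodel.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
newmodel.train()
for epoch in range(10):                      # 10 retrain epochs, as used below
    for data, target in train_loader:        # assumed CIFAR-10 training loader
        if not args.cpu:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        loss = F.cross_entropy(newmodel(data), target)
        loss.backward()
        optimizer.step()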
The comparison of accuracy, GFLOPs, parameters (Param), and model size is shown in the figure below. The network was trained on CIFAR-10 for 50 epochs, and after pruning it was retrained for only 10 epochs.
The detailed network structures before and after pruning, together with a few details worth noting, are shown in the figure below:
The full code is available in my project.
Repo: https://github.com/BBuf/model-compression
This is the second article in the deep learning algorithm optimization series; I will keep studying and updating this series, so feel free to follow the official account.
Welcome to follow my WeChat official account GiantPandaCV, where I look forward to discussing machine learning, deep learning, image algorithms, optimization techniques, competitions, and daily life with you.