This article covers an ICCV 2017 model-compression paper (Network Slimming), which proposes a pruning method built around the BN layer: the BN weights (i.e. the scaling factors) are used to score the importance of each channel, channels whose score falls below a threshold are filtered out, and when the pruned network is re-assembled the neurons of the filtered channels simply do not participate in the connections.
The paper first discusses the advantages and challenges of channel-wise sparsification, and then describes how the BN scaling factors can be used to efficiently identify and prune unimportant channels.
Sparsity can be imposed at different granularities, e.g. weight-level, kernel-level, channel-level, or layer-level. Channel-level sparsification requires cutting all incoming and outgoing connections associated with a channel, which is hard to achieve by directly pruning an already-trained model, because few of its channels happen to have weights close to zero.
The core idea of the paper: introduce a scaling factor $\gamma$ for each channel and multiply it with that channel's output. During training, the network weights and these scaling factors are trained jointly; afterwards, channels with small scaling factors are removed directly and the pruned network is fine-tuned (retrained). The objective function is defined as:
$$L = \sum_{(x, y)} l(f(x, W), y) + \lambda \sum_{\gamma \in \Gamma} g(\gamma)$$
Here $(x, y)$ denotes the training inputs and labels, $W$ denotes the trainable network parameters, and the first term is the ordinary CNN training loss. $g(\cdot)$ is a sparsity-inducing penalty on the scaling factors, and $\lambda$ balances the two terms. In the paper's implementation $g(s) = |s|$, i.e. an L1 penalty, and subgradient descent is used to optimize this non-smooth term. A smooth-L1 penalty can also be used to avoid subgradients (see the sketch after updateBN below).
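As a minimal sketch (not an equation taken from the paper), with $g(\gamma) = |\gamma|$ the extra term only adds a constant-magnitude subgradient $\lambda\,\mathrm{sign}(\gamma)$ to each scaling factor, so a plain SGD step on a scaling factor becomes

$$\gamma \leftarrow \gamma - \eta\left(\frac{\partial l}{\partial \gamma} + \lambda\,\mathrm{sign}(\gamma)\right)$$

where $\eta$ is the learning rate. This is exactly what the updateBN() function further below implements, by adding $\lambda\,\mathrm{sign}(\gamma)$ to the gradient of every BN weight.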
Channel pruning removes all incoming and outgoing connections associated with a channel, directly yielding a smaller network that needs no special software or hardware to run. The scaling factors act as channel selectors: because their regularization term is added to the loss and optimized jointly with the network weights, the network can automatically identify unimportant channels, which can then be removed with almost no loss in accuracy.
The paper directly reuses the $\gamma$ scaling coefficient of the BN layer as the channel-pruning scaling factor, so no extra parameters or overhead are introduced into the network.
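As a reminder (this is the standard batch-normalization definition, not something specific to this paper), BN already applies an affine transform with a learnable per-channel scale $\gamma$ and shift $\beta$, which is why $\gamma$ can double as a channel importance score:

$$\hat{z} = \frac{z_{in} - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}, \qquad z_{out} = \gamma \hat{z} + \beta$$

where $\mu_{\mathcal{B}}$ and $\sigma_{\mathcal{B}}^2$ are the mini-batch mean and variance. A channel whose $\gamma$ is driven to zero contributes (up to the constant $\beta$) nothing to the following layer.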
After adding the scaling-factor regularization to the loss, many scaling factors of the trained model are pushed close to zero. Pruning such a channel then simply means removing the convolution kernel that produces the corresponding feature map.
How is the threshold on the scaling factors determined? Usually a global pruning ratio is specified, e.g. 70%, meaning that 70% of the channels in the whole network are to be removed. The scaling factors of all channels are sorted, and the value at the 70% position is taken as the threshold.
There are two key hyperparameters in this pruning scheme: the pruning percentage and the sparsity regularization coefficient $\lambda$. In the reference implementation they enter as follows:
#--------------------------------------#
# Whether to train with channel sparsity regularization
#--------------------------------------#
parser.add_argument('--sparsity-regularization', '-sr', dest='sr', action='store_true',
                    help='train with channel sparsity regularization')
#--------------------------------------#
# Strength of the channel sparsity regularization
#--------------------------------------#
parser.add_argument('--s', type=float, default=0.0001,
                    help='scale sparse rate (default: 0.0001)')
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        pred = output.data.max(1, keepdim=True)[1]
        loss.backward()
        if args.sr:
            updateBN()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data.item()))
# Additional subgradient descent step on the sparsity-induced penalty term,
# where args.s is the channel sparsity regularization coefficient.
def updateBN():
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.weight.grad.data.add_(args.s * torch.sign(m.weight.data))  # L1 subgradient
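As noted earlier, the non-smooth L1 penalty can be replaced by a smooth alternative so that no explicit subgradient step is needed. Below is a minimal sketch of that variant; it is not part of the paper's code, sparsity_penalty and beta are hypothetical names, and the beta argument of F.smooth_l1_loss requires a reasonably recent PyTorch:

import torch
import torch.nn as nn
import torch.nn.functional as F

def sparsity_penalty(model, beta=1.0):
    # Smooth-L1 penalty summed over all BN scaling factors (gamma).
    penalty = 0.0
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            penalty = penalty + F.smooth_l1_loss(
                m.weight, torch.zeros_like(m.weight), beta=beta, reduction='sum')
    return penalty

# In the training loop, instead of calling updateBN():
# loss = F.cross_entropy(output, target) + args.s * sparsity_penalty(model)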
#-------------------------#
# Count how many channels (BN weights) the whole network contains
#-------------------------#
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]
#----------------------------#
# Collect the absolute value of every channel's BN weight
#----------------------------#
bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index+size)] = m.weight.data.abs().clone()
        index += size
#----------------------------------#
# Sort the weights of all channels
#----------------------------------#
y, i = torch.sort(bn)  # ascending order; returns the sorted values and their indices
#------------------------------------#
# Get the index and threshold corresponding to the pruning percentage
#------------------------------------#
thre_index = int(total * args.percent)
thre = y[thre_index]

pruned = 0  # running count of pruned channels (initialization missing in the original snippet)
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        mask = weight_copy.gt(thre).float().cuda()
        #---------------------------------------#
        # mask.shape[0] is the original number of channels
        # torch.sum(mask) is the number of channels kept after pruning
        # pruned accumulates the total number of channels to be pruned
        #---------------------------------------#
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        #---------------------------------------------------#
        # mask contains 0s and 1s; multiplying the weight and bias by 0
        # effectively removes that channel and its connections
        #---------------------------------------------------#
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        #-------------------------------------------------#
        # cfg stores the number of channels of each layer after pruning
        # cfg_mask stores the per-layer channel mask
        #-------------------------------------------------#
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(
            k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg.append('M')
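With cfg computed, a new, narrower network is built before the weights are copied into it. A minimal sketch, assuming (as in the repository's vggprune.py) that the vgg constructor accepts a cfg list:

newmodel = vgg(dataset=args.dataset, cfg=cfg)
if args.cuda:
    newmodel.cuda()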
layer_id_in_cfg = 0
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
for [m0, m1] in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        #-------------------------------------#
        # Select the corresponding weights from the original model
        #-------------------------------------#
        m1.weight.data = m0.weight.data[idx1.tolist()].clone()
        m1.bias.data = m0.bias.data[idx1.tolist()].clone()
        m1.running_mean = m0.running_mean[idx1.tolist()].clone()
        m1.running_var = m0.running_var[idx1.tolist()].clone()
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()
        if layer_id_in_cfg < len(cfg_mask):  # do not change in final FC
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()  # (C_out, C_in, K, K)
        w1 = w1[idx1.tolist(), :, :, :].clone()
        m1.weight.data = w1.clone()
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        m1.weight.data = m0.weight.data[:, idx0].clone()  # (C_out, C_in)
        m1.bias.data = m0.bias.data.clone()
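After the copy loop, the pruned model is usually saved together with its cfg so that exactly the same architecture can be rebuilt later for fine-tuning. A minimal sketch (it assumes import os; the save directory and file name are illustrative):

torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()},
           os.path.join(args.save, 'pruned.pth.tar'))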
For ResNet architectures, pruning-aware training differs from VGG in that a channel_selection layer is inserted after the BN layer in every block, implemented as follows:
class channel_selection(nn.Module):
    """
    Select channels from the output of a BatchNorm2d layer. It should be put directly after a BatchNorm2d layer.
    The output shape of this layer is determined by the number of 1s in `self.indexes`.
    """
    def __init__(self, num_channels):
        """
        Initialize `indexes` as an all-ones vector with length equal to the number of channels.
        During pruning, the entries of `indexes` that correspond to channels to be pruned are set to 0.
        """
        super(channel_selection, self).__init__()
        self.indexes = nn.Parameter(torch.ones(num_channels))

    def forward(self, input_tensor):
        """
        Parameter
        ---------
        input_tensor: (N, C, H, W). It should be the output of a BatchNorm2d layer.
        """
        selected_index = np.squeeze(np.argwhere(self.indexes.data.cpu().numpy()))
        if selected_index.size == 1:
            selected_index = np.resize(selected_index, (1,))
        output = input_tensor[:, selected_index, :, :]
        return output
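A tiny usage example (with hypothetical values) showing how channel_selection behaves once some entries of indexes have been zeroed during pruning:

cs = channel_selection(num_channels=4)
cs.indexes.data = torch.tensor([1., 0., 1., 0.])  # keep channels 0 and 2
x = torch.randn(1, 4, 8, 8)
print(cs(x).shape)  # torch.Size([1, 2, 8, 8])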
The ResNet building block (a pre-activation Bottleneck) is structured as follows:
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, cfg, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.select = channel_selection(inplanes)
        self.conv1 = nn.Conv2d(cfg[0], cfg[1], kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(cfg[1])
        self.conv2 = nn.Conv2d(cfg[1], cfg[2], kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(cfg[2])
        self.conv3 = nn.Conv2d(cfg[2], planes * 4, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.bn1(x)
        out = self.select(out)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        return out
The training procedure is similar to that of VGG above; the main difference lies in how the weights are assigned to the pruned model, implemented as follows:
old_modules = list(model.modules())
new_modules = list(newmodel.modules())
layer_id_in_cfg = 0
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
conv_count = 0

for layer_id in range(len(old_modules)):
    m0 = old_modules[layer_id]
    m1 = new_modules[layer_id]
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        #---------------------------------------------------------#
        # If the next layer is a channel selection layer, the current
        # BatchNorm2d layer is not pruned; the selection is done by
        # the channel_selection layer instead.
        #---------------------------------------------------------#
        if isinstance(old_modules[layer_id + 1], channel_selection):
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()

            # Configure the channel selection layer of the new model.
            m2 = new_modules[layer_id + 1]
            m2.indexes.data.zero_()
            m2.indexes.data[idx1.tolist()] = 1.0

            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):
                end_mask = cfg_mask[layer_id_in_cfg]
        else:
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            m1.running_var = m0.running_var[idx1.tolist()].clone()
            layer_id_in_cfg += 1
            start_mask = end_mask.clone()
            if layer_id_in_cfg < len(cfg_mask):  # do not change in final FC
                end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        #---------------------------------#
        # The first (input) convolution is copied unmodified
        #---------------------------------#
        if conv_count == 0:
            m1.weight.data = m0.weight.data.clone()
            conv_count += 1
            continue
        if isinstance(old_modules[layer_id-1], channel_selection) or isinstance(old_modules[layer_id-1], nn.BatchNorm2d):
            #------------------------------------------------#
            # This covers the convolutions inside the residual block:
            # they come either after a channel selection layer or after
            # a batch normalization layer.
            #------------------------------------------------#
            conv_count += 1
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            #---------------------------------------------#
            # If the current convolution is not the last one in the
            # residual block, its output channels can be pruned as well;
            # conv_count is used to detect the last convolution
            # (conv_count % 3 == 1), whose output channels are kept.
            #---------------------------------------------#
            if conv_count % 3 != 1:
                w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
            continue
        # Handle the downsampling convolutions: for these we simply copy the weights.
        m1.weight.data = m0.weight.data.clone()
    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()
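Once the weights have been copied, the pruned ResNet is fine-tuned (retrained) to recover accuracy. A minimal sketch, assuming the pruned checkpoint stores 'cfg' and 'state_dict' as in the saving sketch above and that the repo's resnet constructor accepts a cfg list; the optimizer settings are illustrative:

checkpoint = torch.load('pruned.pth.tar')
model = resnet(depth=164, dataset='cifar10', cfg=checkpoint['cfg'])
model.load_state_dict(checkpoint['state_dict'])
optimizer = torch.optim.SGD(model.parameters(), lr=0.01,
                            momentum=0.9, weight_decay=1e-4)
# Then run the same train() loop as before, this time with the
# sparsity update disabled (args.sr = False).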
(Figure: schematic of how the first input convolution layer and a residual block are pruned.)
This article has introduced a channel-pruning method for model compression: channel importance is measured by the BN layer weights, the BN weights are driven toward sparsity during training, a threshold is derived from the configured pruning percentage, and the model is pruned accordingly. The concrete pruning implementations for VGG and ResNet were then presented.
Besides the BN-scale criterion used in this paper, several other filter-importance criteria are commonly used in channel/filter pruning:
- Sensitivity-based pruning determines the pruning ratio of each convolution layer through a per-layer sensitivity analysis; it has to be combined with one of the importance criteria below.
- The geometric-distance criterion evaluates the importance of the filters within one convolution by the pairwise geometric distances between them; intuitively, the farther a filter lies on average from the other filters, the more important it is.
- The L1-norm criterion uses the L1 norm of each filter in a convolution layer as its importance score: the larger the L1 norm, the more important the filter (a small sketch follows this list).
- The L2-norm criterion works the same way with the L2 norm: the larger the L2 norm, the more important the filter.
- The BN-scale criterion evaluates the filters of a convolution by the weights of the BN layer that follows it: the larger the scale, the more important the corresponding filter.
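As a minimal illustration of the L1-norm criterion mentioned in the list above (a generic sketch, not tied to any particular pruning library):

import torch

def filter_l1_scores(conv_weight):
    # conv_weight has shape (C_out, C_in, K, K); return one importance
    # score per output filter, namely the L1 norm of its weights.
    return conv_weight.abs().sum(dim=(1, 2, 3))

# Example: score the 16 filters of a randomly initialized 3x3 convolution.
w = torch.randn(16, 8, 3, 3)
print(filter_l1_scores(w).shape)  # torch.Size([16])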