该文章是ICCV 2017的一篇模型压缩论文,提出了一个针对BN层的剪枝方法:利用BN层的权重(即缩放系数)来评估输入通道的重要程度(score),然后将score低于阈值的通道过滤掉;之后在连接成剪枝后的网络时,已经被过滤的通道对应的神经元就不再参与连接。
论文的主要思想:对于每个通道都引入一个缩放因子 $\gamma$,然后和通道的输出相乘。在训练过程中,联合训练网络权重和这些缩放因子,最后将缩放因子较小的通道直接移除,并对剪枝后的网络进行微调(retrain)。目标函数定义为:
$$L = \sum_{(x, y)} l(f(x, W), y) + \lambda \sum_{\gamma \in \Gamma} g(\gamma)$$
其中 $(x, y)$ 表示训练数据和标签,$W$ 表示网络的可训练参数,第一项是CNN的训练损失函数。$g(\cdot)$ 是施加在缩放因子上的惩罚项,$\lambda$ 是两项的平衡因子。论文的实现过程中选择 $g(s)=|s|$,即 $L_1$ 正则化,并使用次梯度下降法来优化不平滑的 $L_1$ 惩罚项。也可以使用平滑的 $L_1$ 正则项,避免使用次梯度。
论文直接将BN层的缩放系数 $\gamma$ 作为通道裁剪的缩放因子,这样做的好处在于没有给网络带来额外的开销。
在网络剪枝中有两个关键的超参数:一个是剪枝的百分比,另一个是稀疏正则化系数 $\lambda$。它们对于模型剪枝的影响如下:
# Opt-in flag: enable training with the channel sparsity regularization term.
# The parsed value is stored as ``args.sr``.
parser.add_argument(
    '--sparsity-regularization', '-sr',
    dest='sr',
    action='store_true',
    help='train with channel sparsity regularization',
)
# Coefficient of the channel sparsity regularization term (``args.s``).
parser.add_argument(
    '--s',
    type=float,
    default=0.0001,
    help='scale sparse rate (default: 0.0001)',
)
def train(epoch):
    """Run one training epoch over ``train_loader``.

    For every mini-batch the cross-entropy loss is back-propagated; when
    ``args.sr`` is set, ``updateBN()`` is called between the backward pass and
    the optimizer step so that the sparsity subgradient is added to the
    BatchNorm scale gradients before the weights are updated.

    Args:
        epoch: Index of the current epoch; only used in the progress log.
    """
    # NOTE(review): `model.train()`, `optimizer.zero_grad()`, `loss.backward()`,
    # `updateBN()` and `optimizer.step()` were missing from this excerpt and are
    # restored here following the reference Network Slimming training script.
    # `optimizer` is assumed to be the module-level optimizer -- confirm.
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)

        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()

        # The sparsity penalty is applied directly on the gradients, so it must
        # run after backward() and before step().
        if args.sr:
            updateBN()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.1f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data.item()))
# Additional subgradient descent on the sparsity-induced penalty term.
# ``args.s`` is the coefficient of the channel sparsity regularization term.
def updateBN():
    """Add the L1 penalty subgradient to the gradient of every BatchNorm2d weight.

    For each ``nn.BatchNorm2d`` module in ``model``, ``args.s * sign(weight)``
    is added in place to ``weight.grad``.
    """
    bn_layers = (
        layer for layer in model.modules()
        if isinstance(layer, nn.BatchNorm2d)
    )
    for layer in bn_layers:
        gamma = layer.weight
        penalty_grad = args.s * torch.sign(gamma.data)  # L1
        gamma.grad.data.add_(penalty_grad)
# Count how many channels (BatchNorm2d weights) the whole network contains.
total = sum(
    m.weight.data.shape[0]
    for m in model.modules()
    if isinstance(m, nn.BatchNorm2d)
)

# Collect the absolute BatchNorm2d weight of every channel into one flat vector.
bn = torch.zeros(total)
index = 0
for m in model.modules():
    if not isinstance(m, nn.BatchNorm2d):
        continue
    size = m.weight.data.shape[0]
    bn[index:index + size] = m.weight.data.abs().clone()
    index += size

# Sort all channel weights in ascending order; `y` holds the sorted values and
# `i` the corresponding original positions.
y, i = torch.sort(bn)

# Turn the pruning percentage into a position in the sorted vector and read
# the global threshold from it.
thre_index = int(total * args.percent)
thre = y[thre_index]
# Build the per-layer channel masks from the global threshold `thre`.
#   pruned   -- running count of channels that fall at or below the threshold
#   cfg      -- number of remaining channels for each BatchNorm2d layer
#   cfg_mask -- one 0/1 mask per BatchNorm2d layer (1 = keep the channel)
# NOTE(review): `pruned = 0`, the two `mul_` calls, the `cfg` / `cfg_mask`
# appends and the MaxPool2d branch body were missing from this excerpt; they
# are restored to match what the original comments describe.
pruned = 0
cfg = []
cfg_mask = []
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        # Keep the mask on the same device as the layer instead of forcing
        # `.cuda()`, so the in-place multiplications below also work on CPU.
        mask = weight_copy.gt(thre).float().to(m.weight.device)
        # mask.shape[0] is the original channel count and torch.sum(mask) the
        # number of channels kept, so the difference is what gets pruned here.
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        # The mask is 0 or 1: multiplying weight and bias by 0 is equivalent
        # to removing the channel together with its connections.
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        # Presumably the pooling marker used by the VGG-style config list in
        # the reference implementation -- confirm against the model builder.
        cfg.append('M')
# Walk the original model and the pruned model in lockstep and copy over only
# the parameters that belong to the channels kept by the masks.
#   start_mask -- mask of the channels feeding the current layer
#   end_mask   -- mask of the channels produced by the current layer
layer_id_in_cfg = 0
start_mask = torch.ones(3)  # three input channels, all kept
end_mask = cfg_mask[layer_id_in_cfg]
for m0, m1 in zip(model.modules(), newmodel.modules()):
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            # squeeze() collapses a single index to a 0-d array; make it 1-d again.
            idx1 = np.resize(idx1, (1,))

        # Select the surviving channels from the original layer.
        m1.weight.data = m0.weight.data[idx1.tolist()].clone()
        m1.bias.data = m0.bias.data[idx1.tolist()].clone()
        m1.running_mean = m0.running_mean[idx1.tolist()].clone()
        m1.running_var = m0.running_var[idx1.tolist()].clone()

        # Advance to the next layer: its inputs are this layer's outputs.
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()
        if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
            end_mask = cfg_mask[layer_id_in_cfg]

    elif isinstance(m0, nn.Conv2d):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))

        # Weight layout is (C_out, C_in, K, K): slice inputs first, then outputs.
        w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
        w1 = w1[idx1.tolist(), :, :, :].clone()
        m1.weight.data = w1.clone()

    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))

        # Weight layout is (C_out, C_in): only the input features are sliced.
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()
对resnet网络结构进行裁剪训练时,与vgg的区别在于:在每一个block的第一个bn层后面加入了channel selection层,具体如下:
class channel_selection(nn.Module):
    """
    Select channels from the output of BatchNorm2d layer. It should be put directly after BatchNorm2d layer.
    The output shape of this layer is determined by the number of 1 in `self.indexes`.
    """

    def __init__(self, num_channels):
        """
        Initialize the `indexes` with all one vector with the length same as the number of channels.
        During pruning, the places in `indexes` which correspond to the channels to be pruned will be set to 0.
        """
        super(channel_selection, self).__init__()
        self.indexes = nn.Parameter(torch.ones(num_channels))

    def forward(self, input_tensor):
        """
        Parameter
        ---------
        input_tensor: (N,C,H,W). It should be the output of BatchNorm2d layer.

        Returns the channels of `input_tensor` whose entry in `self.indexes`
        is non-zero, in their original order.
        """
        selected_index = np.squeeze(np.argwhere(self.indexes.data.cpu().numpy()))
        if selected_index.size == 1:
            # squeeze() collapses a single index to a 0-d array; restore a 1-d
            # array so the channel dimension is kept in the output.
            selected_index = np.resize(selected_index, (1,))
        output = input_tensor[:, selected_index, :, :]
        return output
resnet的bottleneck block网络结构:
class Bottleneck(nn.Module):
    """Residual block built as BN -> channel selection -> ReLU -> conv, three times over.

    Each convolution is preceded by its BatchNorm2d and ReLU. Only the first
    BatchNorm2d is followed by a `channel_selection` layer. `cfg` supplies the
    channel counts: conv1 maps cfg[0] -> cfg[1], conv2 maps cfg[1] -> cfg[2]
    and conv3 maps cfg[2] -> planes * 4.
    """

    expansion = 4

    def __init__(self, inplanes, planes, cfg, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # NOTE: sub-modules are registered in the same order as they are
        # applied in forward(); keep this order unchanged.
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.select = channel_selection(inplanes)
        self.conv1 = nn.Conv2d(cfg[0], cfg[1], kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(cfg[1])
        self.conv2 = nn.Conv2d(
            cfg[1], cfg[2],
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn3 = nn.BatchNorm2d(cfg[2])
        self.conv3 = nn.Conv2d(cfg[2], planes * 4, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main branch: three BN -> ReLU -> conv stages, with channel selection
        # applied right after the first BN.
        out = self.conv1(self.relu(self.select(self.bn1(x))))
        out = self.conv2(self.relu(self.bn2(out)))
        out = self.conv3(self.relu(self.bn3(out)))

        # Shortcut branch: identity unless a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return out
# Copy the surviving parameters from the original ResNet into the pruned one.
#   start_mask -- mask of the channels feeding the current layer
#   end_mask   -- mask of the channels produced by the current layer
#   conv_count -- number of convolutions handled so far (input conv included)
# NOTE(review): the `else:` of the BatchNorm2d branch, the two `continue`
# statements in the Conv2d branch (expressed below as if/elif/else) and the
# `indexes.data.zero_()` call were missing from this excerpt. They are restored
# following the reference Network Slimming ResNet pruning script; without them
# the downsample copy would overwrite every convolution and the channel
# selection layers would keep all channels.
old_modules = list(model.modules())
new_modules = list(newmodel.modules())
layer_id_in_cfg = 0
start_mask = torch.ones(3)
end_mask = cfg_mask[layer_id_in_cfg]
conv_count = 0

for layer_id in range(len(old_modules)):
    m0 = old_modules[layer_id]
    m1 = new_modules[layer_id]
    if isinstance(m0, nn.BatchNorm2d):
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            # squeeze() collapses a single index to a 0-d array; make it 1-d again.
            idx1 = np.resize(idx1, (1,))

        if isinstance(old_modules[layer_id + 1], channel_selection):
            # If the next layer is the channel selection layer, then the
            # current batchnorm 2d layer won't be pruned.
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()

            # We need to set the channel selection layer: clear it first,
            # because `indexes` is initialised to all ones, then mark the
            # channels that are kept.
            m2 = new_modules[layer_id + 1]
            m2.indexes.data.zero_()
            m2.indexes.data[idx1.tolist()] = 1.0
        else:
            # Regular BatchNorm2d: keep only the surviving channels.
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            m1.running_var = m0.running_var[idx1.tolist()].clone()

        # Advance to the next mask: the outputs of this layer feed the next one.
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()
        if layer_id_in_cfg < len(cfg_mask):  # do not change in Final FC
            end_mask = cfg_mask[layer_id_in_cfg]

    elif isinstance(m0, nn.Conv2d):
        if conv_count == 0:
            # The input convolution is copied without any modification.
            m1.weight.data = m0.weight.data.clone()
            conv_count += 1
        elif isinstance(old_modules[layer_id - 1], (channel_selection, nn.BatchNorm2d)):
            # This covers the convolutions in the residual block.
            # The convolutions are either after the channel selection layer
            # or after the batch normalization layer.
            conv_count += 1
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()

            # If the current convolution is not the last convolution in the
            # residual block, then we can change the number of output
            # channels. Currently we use `conv_count` to detect whether it is
            # such convolution; the last one keeps all of its output channels.
            if conv_count % 3 != 1:
                w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()
        else:
            # We need to consider the case where there are downsampling
            # convolutions. For these convolutions, we just copy the weights.
            m1.weight.data = m0.weight.data.clone()

    elif isinstance(m0, nn.Linear):
        idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
        if idx0.size == 1:
            idx0 = np.resize(idx0, (1,))

        # Only the input features of the classifier are sliced.
        m1.weight.data = m0.weight.data[:, idx0].clone()
        m1.bias.data = m0.bias.data.clone()