The network consists of four parts:
(1)PillarVFE
(2)PointPillarScatter
(3)BaseBEVBackbone
(4)AnchorHeadSingle
Pruning is applied mainly to BaseBEVBackbone. The BaseBEVBackbone network structure diagram is shown below.
In detail, the module is:
(backbone_2d): BaseBEVBackbone(
  (blocks): ModuleList(
    (0): Sequential(
      (0): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
      (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (2): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (5): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (8): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (9): ReLU()
      (10): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (11): BatchNorm2d(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (12): ReLU()
    )
    (1): Sequential(
      (0): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
      (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (2): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (5): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (8): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (9): ReLU()
      (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (11): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (12): ReLU()
      (13): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (14): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (15): ReLU()
      (16): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (17): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (18): ReLU()
    )
    (2): Sequential(
      (0): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
      (1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (2): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (3): ReLU()
      (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (5): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (8): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (9): ReLU()
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (11): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (12): ReLU()
      (13): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (14): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (15): ReLU()
      (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (17): BatchNorm2d(256, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (18): ReLU()
    )
  )
  (deblocks): ModuleList(
    (0): Sequential(
      (0): ConvTranspose2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): ConvTranspose2d(128, 128, kernel_size=(2, 2), stride=(2, 2), bias=False)
      (1): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(4, 4), bias=False)
      (1): BatchNorm2d(128, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
)
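To see where the backbone's compute goes, the multiply-accumulates (MACs) of the 3×3 convolutions in the printout above can be counted by hand. The sketch below is only an estimate under stated assumptions: it assumes a 496 × 432 (H × W) BEV pseudo-image, which is what the KITTI pointpillar.yaml grid is believed to produce, ignores the ConvTranspose2d deblocks, and uses hypothetical helper names.

```python
def conv3x3_macs(c_in: int, c_out: int, h_out: int, w_out: int) -> int:
    """Multiply-accumulates of a bias-free 3x3 Conv2d producing an h_out x w_out map."""
    return c_in * c_out * 9 * h_out * w_out


def backbone_block_macs(h: int, w: int,
                        blocks=((64, 64, 3), (64, 128, 5), (128, 256, 5))) -> list:
    """Per-block MACs of BaseBEVBackbone.blocks for an h x w input.

    Each entry of `blocks` is (in_channels, num_filters, extra_convs), read off the
    printout: one stride-2 conv followed by `extra_convs` stride-1 convs.
    """
    per_block = []
    for c_in, c_out, extra_convs in blocks:
        # ZeroPad2d(1) + 3x3 conv, stride 2, padding 0: out = (in + 2 - 3) // 2 + 1
        h, w = (h - 1) // 2 + 1, (w - 1) // 2 + 1
        macs = conv3x3_macs(c_in, c_out, h, w)
        macs += extra_convs * conv3x3_macs(c_out, c_out, h, w)
        per_block.append(macs)
    return per_block


if __name__ == "__main__":
    # Assumed KITTI PointPillars pseudo-image size (H x W)
    print(backbone_block_macs(496, 432))
```

With this assumed input the three blocks come to roughly 7.9, 10.9 and 10.9 GMACs, and every 3×3 conv except the first one of blocks 1 and 2 costs about 1.97 GMACs: halving H and W quarters the spatial size while doubling the width quadruples C_in × C_out. Because removing an output channel of one conv also removes an input channel of the next conv, channel pruning saves compute in both layers.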
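2.1 Sparse training
An L1 penalty is applied to the BN layer parameters (the scaling factors γ), pushing most of them toward zero so that pruning the corresponding channels later has less impact on model accuracy.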
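import torch

def updateBN(model, s=0.0001):
    # L1 sparsity penalty on BN scaling factors: add s * sign(gamma) to each gamma's gradient
    for m in model.modules():
        if isinstance(m, torch.nn.BatchNorm2d) and m.weight.grad is not None:
            m.weight.grad.data.add_(s * torch.sign(m.weight.data))  # L1

# In the training loop, between backward() and the optimizer step:
loss.backward()
updateBN(model)
optimizer.step()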
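The updateBN step is the subgradient of an L1 penalty on the BN scaling factors, in the spirit of Network Slimming. Writing Γ for the set of BatchNorm2d weights γ, s for the sparsity strength (0.0001 above) and L_det for the detection loss, the objective being optimized is, in effect:

$$
\mathcal{L} = \mathcal{L}_{\text{det}} + s \sum_{\gamma \in \Gamma} \lvert \gamma \rvert,
\qquad
\frac{\partial \mathcal{L}}{\partial \gamma} = \frac{\partial \mathcal{L}_{\text{det}}}{\partial \gamma} + s\,\operatorname{sign}(\gamma).
$$

The second term is exactly what m.weight.grad.data.add_(s*torch.sign(m.weight.data)) contributes; at γ = 0, torch.sign returns 0, which is a valid subgradient of |γ|. With plain SGD and learning rate η, each step therefore pulls γ toward zero by an extra η·s; with an adaptive optimizer such as Adam the added term is rescaled together with the rest of the gradient, so s may need tuning.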
2.2 Pruning the sparsely trained model (Network Slimming)
(1) Compute the threshold from the pruning ratio (percent)
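import torch
import torch.nn as nn

percent = 0.7  # pruning ratio: fraction of BN channels (ranked by |gamma|) to remove

# Count all BatchNorm2d channels in the model
total = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

# Collect |gamma| of every BatchNorm2d channel into one vector
bn = torch.zeros(total)
index = 0
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        size = m.weight.data.shape[0]
        bn[index:(index + size)] = m.weight.data.abs().clone()
        index += size

# Global threshold: the |gamma| value at position int(total * percent) of the sorted vector
y, i = torch.sort(bn)
thre_index = int(total * percent)
thre = y[thre_index]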
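The same global threshold can be computed more compactly with torch.cat. This is a minimal sketch, not the original script; global_bn_threshold is a hypothetical helper name, and the clamp on the index is an addition so that percent = 1.0 does not run past the end of the sorted vector.

```python
import torch
import torch.nn as nn


def global_bn_threshold(model: nn.Module, percent: float) -> float:
    """Global |gamma| threshold over all BatchNorm2d layers, indexed like the loop above."""
    gammas = torch.cat([
        m.weight.detach().abs().flatten().cpu()
        for m in model.modules() if isinstance(m, nn.BatchNorm2d)
    ])
    sorted_gammas, _ = torch.sort(gammas)
    # Clamp so that percent = 1.0 does not index past the end
    thre_index = min(int(gammas.numel() * percent), gammas.numel() - 1)
    return sorted_gammas[thre_index].item()


if __name__ == "__main__":
    toy = nn.Sequential(nn.BatchNorm2d(4), nn.BatchNorm2d(6))
    gammas = torch.arange(1, 11, dtype=torch.float32) / 10  # 0.1 ... 1.0
    with torch.no_grad():
        toy[0].weight.copy_(gammas[:4])
        toy[1].weight.copy_(gammas[4:])
    print(global_bn_threshold(toy, 0.7))  # about 0.8
```

Because the mask in step (2) uses gt(thre), channels with |γ| ≤ thre are pruned; with distinct values that is int(total * percent) + 1 channels, i.e. 8 of 10 (0.1 to 0.8) in the toy example.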
(2) Generate cfg_index (the number of channels kept in each BN layer) and cfg_mask (the per-channel keep masks)
pruned = 0
cfg_index = []  # number of channels kept in each BatchNorm2d
cfg_mask = []   # 0/1 keep mask for each BatchNorm2d
for k, m in enumerate(model.modules()):
    if isinstance(m, nn.BatchNorm2d):
        weight_copy = m.weight.data.abs().clone()
        # Keep the channels whose |gamma| exceeds the global threshold
        mask = weight_copy.cpu().gt(thre).float().to(m.weight.device)
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        # Zero gamma and beta of the pruned channels in the original model (in place)
        m.weight.data.mul_(mask)
        m.bias.data.mul_(mask)
        cfg_index.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
              format(k, mask.shape[0], int(torch.sum(mask))))
    elif isinstance(m, nn.MaxPool2d):
        cfg_index.append('M')
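If the original model should stay untouched while the masks are being inspected or adjusted, the masks can be built without the in-place mul_ calls. A minimal sketch under that assumption (bn_channel_masks is a hypothetical name); it creates each mask on its BN layer's own device, so it also runs on CPU-only machines, and it drops the MaxPool2d branch because this network has no MaxPool2d layers.

```python
import torch
import torch.nn as nn


def bn_channel_masks(model: nn.Module, thre: float):
    """Build cfg_index / cfg_mask like step (2), without modifying the model."""
    cfg_index, cfg_mask = [], []
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            # 1.0 where |gamma| > thre (channel kept), 0.0 otherwise
            mask = m.weight.detach().abs().gt(thre).float()
            cfg_index.append(int(mask.sum().item()))
            cfg_mask.append(mask.clone())
    return cfg_index, cfg_mask
```

For example, with γ = (0.9, 0.1, 0.6) and (-0.7, 0.2) in two BN layers and thre = 0.5, this returns cfg_index = [2, 1], since the sign of γ is ignored.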
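(3) For BN layers that should not be pruned, set all entries of that layer's cfg_mask to 1.
For example:
cfg_mask[0][:] = 1  # do not prune the first BN layer
Notes: 1) There is probably a better way to do this, and the right choice depends on the specific network; this is just a working implementation.
2) Many layers cannot be pruned, so check each one carefully.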
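In the BaseBEVBackbone rewrite of step (4), some BN widths are hard-coded in the constructor: the first conv of block 0, the last conv of each block (whose outputs feed the deblocks and the next block), and the three deblock BN layers. Assuming the only BatchNorm2d layers in model.modules() are the 19 in backbone_2d, in the printed order (PillarVFE in OpenPCDet uses BatchNorm1d), these are cfg_mask indices 0, 3, 9, 15, 16, 17 and 18. Note also that the loop in step (2) has already multiplied every BN's γ and β by its mask, so resetting a mask to 1 afterwards does not restore the zeroed channels. The sketch below therefore protects the layers first and zeroes afterwards, and is meant to be used with masks built without those two mul_ lines; KEEP_IDS and protect_and_zero are hypothetical names.

```python
import torch
import torch.nn as nn

# BN layers whose width is hard-coded in the rewritten BaseBEVBackbone (assumed order)
KEEP_IDS = (0, 3, 9, 15, 16, 17, 18)


def protect_and_zero(model: nn.Module, cfg_index: list, cfg_mask: list, keep_ids=KEEP_IDS):
    """Force full masks for protected BN layers, then zero gamma/beta of pruned channels."""
    bns = [m for m in model.modules() if isinstance(m, nn.BatchNorm2d)]
    assert len(bns) == len(cfg_mask), "cfg_mask must have one entry per BatchNorm2d"
    for i in keep_ids:
        cfg_mask[i] = torch.ones_like(cfg_mask[i])
        cfg_index[i] = cfg_mask[i].numel()
    with torch.no_grad():
        for bn, mask in zip(bns, cfg_mask):
            mask = mask.to(bn.weight.device)
            bn.weight.mul_(mask)
            bn.bias.mul_(mask)
    return cfg_index, cfg_mask
```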
(4) Build the pruned model architecture from cfg_index
newmodel = build_network(model_cfg=cfg.MODEL, num_class=len(cfg.CLASS_NAMES), dataset=test_set, cfg_index=cfg_index)
newmodel = newmodel.to(device='cuda:0')
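Note: build_network has to be modified here so that it accepts cfg_index and passes it on to the backbone. Since I mainly prune BaseBEVBackbone, the input and output channel counts of each of its convolutions are taken from cfg_index, as follows: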
import numpy as np
import torch
import torch.nn as nn


class BaseBEVBackbone(nn.Module):
    def __init__(self, model_cfg, input_channels, cfg_index=None):
        super().__init__()
        self.model_cfg = model_cfg

        if self.model_cfg.get('LAYER_NUMS', None) is not None:
            assert len(self.model_cfg.LAYER_NUMS) == len(self.model_cfg.LAYER_STRIDES) == len(self.model_cfg.NUM_FILTERS)
            layer_nums = self.model_cfg.LAYER_NUMS
            layer_strides = self.model_cfg.LAYER_STRIDES
            num_filters = self.model_cfg.NUM_FILTERS
        else:
            layer_nums = layer_strides = num_filters = []

        if self.model_cfg.get('UPSAMPLE_STRIDES', None) is not None:
            assert len(self.model_cfg.UPSAMPLE_STRIDES) == len(self.model_cfg.NUM_UPSAMPLE_FILTERS)
            num_upsample_filters = self.model_cfg.NUM_UPSAMPLE_FILTERS
            upsample_strides = self.model_cfg.UPSAMPLE_STRIDES
        else:
            upsample_strides = num_upsample_filters = []

        num_levels = len(layer_nums)
        c_in_list = [input_channels, *num_filters[:-1]]
        self.blocks = nn.ModuleList()
        self.deblocks = nn.ModuleList()

        # One channel count per BatchNorm2d (16 in blocks, 3 in deblocks); the default is the unpruned network
        if cfg_index is None:
            cfg_index = [64, 64, 64, 64, 128, 128, 128, 128, 128, 128, 256, 256, 256, 256, 256, 256, 128, 128, 128]
        cfg = cfg_index

        for idx in range(num_levels):
            if idx == 0:
                # Block 0: the first and last convs keep 64 channels; cfg[1] and cfg[2] are pruned widths
                cur_layers = [
                    nn.ZeroPad2d(1),
                    nn.Conv2d(
                        64, 64, kernel_size=3,
                        stride=layer_strides[idx], padding=0, bias=False
                    ),
                    nn.BatchNorm2d(64, eps=1e-3, momentum=0.01),
                    nn.ReLU()
                ]
                for k in range(3):
                    if k == 0:
                        cur_layers.extend([
                            nn.Conv2d(64, cfg[k + 1], kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(cfg[k + 1], eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
                    if k == 1:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k + 0], cfg[k + 1], kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(cfg[k + 1], eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
                    if k == 2:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k + 0], 64, kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(64, eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
            elif idx == 1:
                # Block 1: cfg[4]..cfg[8] are pruned widths; the last conv keeps 128 channels
                cur_layers = [
                    nn.ZeroPad2d(1),
                    nn.Conv2d(
                        64, cfg[4], kernel_size=3,
                        stride=layer_strides[idx], padding=0, bias=False
                    ),
                    nn.BatchNorm2d(cfg[4], eps=1e-3, momentum=0.01),
                    nn.ReLU()
                ]
                for k in range(5):
                    if k == 4:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k + 4], 128, kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(128, eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
                    else:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k + 4], cfg[k + 5], kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(cfg[k + 5], eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
            elif idx == 2:
                # Block 2: cfg[10]..cfg[14] are pruned widths; the last conv keeps 256 channels
                cur_layers = [
                    nn.ZeroPad2d(1),
                    nn.Conv2d(
                        128, cfg[10], kernel_size=3,
                        stride=layer_strides[idx], padding=0, bias=False
                    ),
                    nn.BatchNorm2d(cfg[10], eps=1e-3, momentum=0.01),
                    nn.ReLU()
                ]
                for k in range(5):
                    if k == 4:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k + 10], 256, kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(256, eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
                    else:
                        cur_layers.extend([
                            nn.Conv2d(cfg[k + 10], cfg[k + 11], kernel_size=3, padding=1, bias=False),
                            nn.BatchNorm2d(cfg[k + 11], eps=1e-3, momentum=0.01),
                            nn.ReLU()
                        ])
            self.blocks.append(nn.Sequential(*cur_layers))

            # Deblocks are not pruned: their widths still come from the model config
            if len(upsample_strides) > 0:
                stride = upsample_strides[idx]
                if stride >= 1:
                    self.deblocks.append(nn.Sequential(
                        nn.ConvTranspose2d(
                            num_filters[idx], num_upsample_filters[idx],
                            upsample_strides[idx],
                            stride=upsample_strides[idx], bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))
                else:
                    # np.int was removed in NumPy 1.24; the builtin int is equivalent here
                    stride = np.round(1 / stride).astype(int)
                    self.deblocks.append(nn.Sequential(
                        nn.Conv2d(
                            num_filters[idx], num_upsample_filters[idx],
                            stride,
                            stride=stride, bias=False
                        ),
                        nn.BatchNorm2d(num_upsample_filters[idx], eps=1e-3, momentum=0.01),
                        nn.ReLU()
                    ))

        c_in = sum(num_upsample_filters)
        if len(upsample_strides) > num_levels:
            self.deblocks.append(nn.Sequential(
                nn.ConvTranspose2d(c_in, c_in, upsample_strides[-1], stride=upsample_strides[-1], bias=False),
                nn.BatchNorm2d(c_in, eps=1e-3, momentum=0.01),
                nn.ReLU(),
            ))

        self.num_bev_features = c_in

    def forward(self, data_dict):
        """
        Args:
            data_dict:
                spatial_features
        Returns:
            data_dict with 'spatial_features_2d' added
        """
        spatial_features = data_dict['spatial_features']
        ups = []
        ret_dict = {}
        x = spatial_features
        for i in range(len(self.blocks)):
            x = self.blocks[i](x)

            stride = int(spatial_features.shape[2] / x.shape[2])
            ret_dict['spatial_features_%dx' % stride] = x
            if len(self.deblocks) > 0:
                ups.append(self.deblocks[i](x))
            else:
                ups.append(x)

        if len(ups) > 1:
            x = torch.cat(ups, dim=1)
        elif len(ups) == 1:
            x = ups[0]

        if len(self.deblocks) > len(self.blocks):
            x = self.deblocks[-1](x)

        data_dict['spatial_features_2d'] = x
        return data_dict
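The if/elif branches above follow one rule: the n-th BN layer, counted across all blocks, gets width cfg[n], except after the first conv of block 0 and the last conv of each block, which keep NUM_FILTERS. A more compact equivalent of the blocks construction is sketched below. build_pruned_blocks is a hypothetical helper, and its defaults (64 input channels, layer_nums 3, 5, 5, strides 2, 2, 2, num_filters 64, 128, 256) are read off the printout rather than from the config file; the fixed-width and deblock entries of cfg are ignored, as in the constructor above.

```python
import torch
import torch.nn as nn


def build_pruned_blocks(cfg, in_channels=64, num_filters=(64, 128, 256),
                        layer_nums=(3, 5, 5), layer_strides=(2, 2, 2)):
    """Rebuild BaseBEVBackbone.blocks from cfg (one channel count per BatchNorm2d)."""
    blocks = nn.ModuleList()
    bn_id = 0            # running index of the BN layer that follows each conv
    c_in = in_channels
    for n_extra, width, stride in zip(layer_nums, num_filters, layer_strides):
        layers = [nn.ZeroPad2d(1)]
        for j in range(n_extra + 1):
            # Fixed width: the first conv of block 0 and the last conv of every block
            fixed = bn_id == 0 or j == n_extra
            c_out = width if fixed else cfg[bn_id]
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3,
                          stride=stride if j == 0 else 1,
                          padding=0 if j == 0 else 1, bias=False),
                nn.BatchNorm2d(c_out, eps=1e-3, momentum=0.01),
                nn.ReLU(),
            ]
            c_in = c_out
            bn_id += 1
        blocks.append(nn.Sequential(*layers))
    return blocks
```

Because the layer order matches the printout (ZeroPad2d, then Conv2d, BatchNorm2d, ReLU triples), the per-module copy loop in step (5) sees the same module sequence as with the class above.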
(5) Prune the conv and BN layer parameters (copy the kept channels into the new model)
import os
import numpy as np
import torch
import torch.nn as nn

old_modules = list(model.modules())
new_modules = list(newmodel.modules())

# First copy every tensor whose shape is unchanged (PillarVFE, the ConvTranspose2d deblocks,
# the detection-head convs and their biases). The loop below only handles BatchNorm2d and
# Conv2d layers followed by BatchNorm2d, so those other layers would otherwise keep their
# random initialization.
new_state = newmodel.state_dict()
for key, value in model.state_dict().items():
    if key in new_state and new_state[key].shape == value.shape:
        new_state[key] = value.clone()
newmodel.load_state_dict(new_state)

layer_id_in_cfg = 0
start_mask = torch.ones(64)  # input channels of the first conv: all kept
end_mask = cfg_mask[layer_id_in_cfg]
conv_count = 0
bn_count = 0
for layer_id in range(len(old_modules)):
    m0 = old_modules[layer_id]
    m1 = new_modules[layer_id]
    if isinstance(m0, nn.BatchNorm2d):
        # Indices of the channels kept in this BN layer
        idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
        if idx1.size == 1:
            idx1 = np.resize(idx1, (1,))
        if bn_count == 0:
            # The first BN layer is not pruned: copy it unchanged
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()
        else:
            m1.weight.data = m0.weight.data[idx1.tolist()].clone()
            m1.bias.data = m0.bias.data[idx1.tolist()].clone()
            m1.running_mean = m0.running_mean[idx1.tolist()].clone()
            m1.running_var = m0.running_var[idx1.tolist()].clone()
        bn_count += 1
        # This layer's mask becomes the input mask of the next conv
        layer_id_in_cfg += 1
        start_mask = end_mask.clone()
        if layer_id_in_cfg < len(cfg_mask):
            end_mask = cfg_mask[layer_id_in_cfg]
    elif isinstance(m0, nn.Conv2d):
        if conv_count == 0:
            # The first conv (block 0, 64 -> 64) is not pruned
            m1.weight.data = m0.weight.data.clone()
            conv_count += 1
            continue
        if layer_id == (len(old_modules) - 1):
            m1.weight.data = m0.weight.data.clone()
            continue
        if isinstance(old_modules[layer_id + 1], nn.BatchNorm2d):
            # Conv followed by BN: input channels follow the previous BN's mask (start_mask),
            # output channels follow the next BN's mask (end_mask)
            conv_count += 1
            idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
            idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
            print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1,))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1,))
            w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
            w1 = w1[idx1.tolist(), :, :, :].clone()
            m1.weight.data = w1.clone()

# Also store cfg_index, which is needed to rebuild the pruned network
torch.save({'cfg': cfg, 'cfg_index': cfg_index, 'model_state': newmodel.state_dict()},
           os.path.join('./', 'pruned_90.pth'))
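The core of the copy loop is index selection along two axes: a conv weight of shape (C_out, C_in, k, k) keeps the rows chosen by the following BN's mask and the columns chosen by the previous BN's mask, and that BN keeps the entries chosen by its own mask. A minimal sketch for one bias-free conv followed by its BN (copy_pruned_conv_bn is a hypothetical helper); torch.index_select with a 1-D index tensor also avoids the np.squeeze / np.resize handling of single-channel masks.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def copy_pruned_conv_bn(old_conv: nn.Conv2d, old_bn: nn.BatchNorm2d,
                        new_conv: nn.Conv2d, new_bn: nn.BatchNorm2d,
                        in_mask: torch.Tensor, out_mask: torch.Tensor) -> None:
    """Copy the kept input/output channels of a bias-free conv and its BN into smaller layers."""
    device = old_conv.weight.device
    in_idx = torch.nonzero(in_mask.to(device), as_tuple=False).flatten()
    out_idx = torch.nonzero(out_mask.to(device), as_tuple=False).flatten()
    # Select input channels (dim 1) first, then output channels (dim 0)
    w = old_conv.weight.index_select(1, in_idx).index_select(0, out_idx)
    new_conv.weight.copy_(w)
    # BN parameters and running statistics are indexed by output channel
    for name in ("weight", "bias", "running_mean", "running_var"):
        getattr(new_bn, name).copy_(getattr(old_bn, name).index_select(0, out_idx))
```

In terms of the loop above, in_mask plays the role of start_mask and out_mask the role of end_mask.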
(6) Use the new model, load its parameters and evaluate it
First, preprocess the data:
python -m pcdet.datasets.kitti.kitti_dataset create_kitti_infos tools/cfgs/dataset_configs/kitti_dataset.yaml
Then:
(1) Environment: array g03, container zxw_compression, directory /data/OpenPCDet-master/tools
(2) Commands:
Test:
CUDA_VISIBLE_DEVICES=2 python test.py --cfg_file cfgs/kitti_models/pointpillar.yaml --batch_size 1 --ckpt checkpoint_epoch_90.pth
Training:
python train.py --cfg_file cfgs/kitti_models/pointpillar.yaml --batch_size 8 --epochs 100
(3) Each time you switch to a different OpenPCDet-master copy, run:
python setup.py develop
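When evaluating the pruned checkpoint, the network has to be built with the same cfg_index that was used for pruning: if it is built without cfg_index, the rewritten BaseBEVBackbone falls back to the default widths and its tensors no longer match the pruned state_dict. In plain PyTorch, load_state_dict reports such size mismatches as a RuntimeError even with strict=False, as the toy check below illustrates with single conv layers standing in for the backbone (try_load is a hypothetical helper). OpenPCDet's own checkpoint-loading helper may handle mismatches differently, so this is only an illustration.

```python
import torch.nn as nn


def try_load(state_dict, model: nn.Module) -> bool:
    """Return True if state_dict loads into model, False on a shape mismatch."""
    try:
        model.load_state_dict(state_dict, strict=False)
        return True
    except RuntimeError as err:
        print("load failed:", str(err).splitlines()[0])
        return False


if __name__ == "__main__":
    pruned = nn.Conv2d(64, 40, 3, padding=1, bias=False)   # stand-in for a pruned conv
    default = nn.Conv2d(64, 64, 3, padding=1, bias=False)  # built with the default width
    rebuilt = nn.Conv2d(64, 40, 3, padding=1, bias=False)  # rebuilt from the same widths
    print(try_load(pruned.state_dict(), default))
    print(try_load(pruned.state_dict(), rebuilt))
```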
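Conclusions:
(1) Not every model can be accelerated by pruning: many layers have no BN, or there are too few prunable layers, so the speedup is small.
(2) The results show that in PCDet the PointPillars network itself takes little time; most of the time goes to the NMS module in post-processing. Why this module is so costly has not yet been investigated.
(3) Pruning the Backbone2d layers slows down post-processing; this needs further study.
(4) Networks with many Conv2d+BN pairs and many layers, such as ResNet-152, can be accelerated by pruning. Note: "pruning" in this article always means channel pruning based on BN parameters; other kinds of pruning are not covered.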