Using Network Slimming to prune FSSD, I get 79.64 mAP on VOC07 at 150 FPS on a TITAN X. Link: https://github.com/dlyldxwl/fssd.pytorch — if you find it useful, please give it a star.
In a caffemodel, each convolutional layer's weight blob has shape output-channel × input-channel × height × width. This post takes each filter (dimension input-channel × height × width), computes output-channel × input-channel weight sums — one per height × width slice — and treats any sum below a given threshold as an inefficient connection, which is then zeroed out. The code is fairly rough, but the comments are detailed, so it should speak for itself.
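In other words, each connection between an output channel and an input channel gets one score: the sum of the absolute values of its height × width slice. As a small standalone illustration (the random array and names here are mine, just to show the shapes; they are not part of the script below), the same scores can be computed with one vectorized numpy call:

import numpy as np

# weight blob layout: (out_channels, in_channels, kernel_h, kernel_w),
# i.e. what net.params[layer][0].data gives you in pycaffe
weight = np.random.randn(4, 3, 3, 3).astype(np.float32)

# one score per (output channel, input channel) connection:
# the sum of |w| over the kernel_h x kernel_w slice
scores = np.abs(weight).sum(axis=(2, 3))   # shape (4, 3)

# connections whose score falls below the threshold are the ones to zero out
mask = scores < 0.0001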
# coding:utf-8
# by chen yh
import caffe
import numpy as np
import matplotlib.pyplot as plt


def weight_0(prototxt, model, layer, threshold):
    caffe.set_mode_gpu()
    net = caffe.Net(prototxt, model, caffe.TEST)
    weight = net.params[layer][0].data
    bias = net.params[layer][1].data   # loaded for completeness, not modified below
    sum_l1 = []
    for i in range(weight.shape[0]):
        for j in range(weight.shape[1]):
            # i is the filter (output channel) index, j is the input channel it connects to.
            # For each connection, sum the absolute values of its height x width slice
            # (an "L1-norm-like" weight sum); i and j are stored alongside the sum so the
            # original weight can be indexed directly when pruning later.
            sum_l1.append((i, j, np.sum(abs(weight[i, j, :, :]))))
    display(sum_l1, 128)   # print the 128 smallest sums after sorting ascending
    l1_plot(sum_l1)        # plot the sums against the out*input index, to decide how many to prune
    weight_l1 = []
    for i in sum_l1:
        weight_l1.append(i[2])   # list containing only the sums
    for i, weight_sum in enumerate(weight_l1):
        if weight_sum < threshold:
            out_channel_sort = sum_l1[i][0]
            input_channel_sort = sum_l1[i][1]
            weight[out_channel_sort, input_channel_sort, :, :] = 0   # zero out connections below the threshold
    net.save("new.caffemodel")


def l1_plot(weight_l1):
    weight_l1_n = []
    for i in weight_l1:
        weight_l1_n.append(i[2])
    weight_l1_n.sort()
    x = [i for i in range(len(weight_l1_n))]
    plt.plot(x, weight_l1_n, label="per-connection weight sum")
    plt.legend()
    plt.show()


def display(weight_l1, top_k):
    weight_l1_n = []
    for i in weight_l1:
        weight_l1_n.append(i[2])
    weight_l1_n.sort()
    print([weight_l1_n[i] for i in range(top_k)])


root = "/home/cyh/python_file/"
prototxt = root + "deploy.prototxt"
model = root + "VGG_coco_SSD_300x300_iter_400000.caffemodel"
weight_0(prototxt, model, 'fc7', 0.0001)
Notes: 1. The "L1 norm" in the code is not a strict L1 norm; it is just the sum of the absolute values over each height × width slice (the same score shown in the snippet near the top of this post).
2. The code only implements the zeroing step. If you want to actually reduce storage, the resulting sparse matrix has to be stored in CSC or CSR format; you can take a look at this link, which explains the idea and also gives code, so implementing it should not be hard — see the first sketch after this list.
3. What this code does is not the same as Deep Compression. First, the paper prunes at a fixed ratio, whereas I prune by threshold; second, the paper, as far as I remember, prunes individual weights, whereas my unit is a height × width slice (see the second sketch after this list for the ratio-based variant). What they do have in common is that neither improves forward inference time, so I did not go on with the storage, retraining, and so on; if anyone has done that part, please share it with me.
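For point 2, a minimal sketch of what the CSR storage could look like with scipy.sparse; this is only my illustration (the blob shape, the reshape to 2-D, and the file name are assumptions), not the code from the link:

import numpy as np
from scipy import sparse

# a pruned weight blob; the shape here is made up for the example
weight = np.zeros((1024, 512, 3, 3), dtype=np.float32)
out_c, in_c, kh, kw = weight.shape

# flatten each filter into a row so the 4-D blob becomes a 2-D matrix
weight_2d = weight.reshape(out_c, in_c * kh * kw)

# CSR keeps only the non-zero values plus two index arrays
weight_csr = sparse.csr_matrix(weight_2d)
sparse.save_npz("fc7_weight_csr.npz", weight_csr)

# restore later with:
# sparse.load_npz("fc7_weight_csr.npz").toarray().reshape(out_c, in_c, kh, kw)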
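For point 3, pruning by ratio (as the paper does) only changes how the cutoff is chosen; a rough sketch of that variant, where the 30% ratio is an arbitrary example of mine:

import numpy as np

# weight: same layout as above, (out_c, in_c, kh, kw); random values just for the demo
weight = np.random.randn(8, 4, 3, 3).astype(np.float32)
scores = np.abs(weight).sum(axis=(2, 3))           # one sum per (out, in) connection

prune_ratio = 0.3                                  # prune the 30% of connections with the smallest sums
cutoff = np.percentile(scores, prune_ratio * 100)

weight[scores < cutoff] = 0                        # zero whole height x width slices below the cutoff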
If I have time tomorrow I will add comments to the channel pruning py file I wrote and then post it as well; that approach does speed up inference time, and there is no need to worry about storage and the like either.
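Until then, here is a rough sketch of the general idea behind channel pruning — score whole output channels and drop the weakest ones, so the layer (and the next layer's input) actually gets smaller. This is only an illustration of the idea, not the file linked below:

import numpy as np

# weight: (out_c, in_c, kh, kw); channel pruning scores whole filters (output channels)
weight = np.random.randn(8, 4, 3, 3).astype(np.float32)
bias = np.random.randn(8).astype(np.float32)

channel_scores = np.abs(weight).sum(axis=(1, 2, 3))          # one score per output channel
keep = channel_scores >= np.percentile(channel_scores, 30)   # e.g. drop the weakest 30%

# the layer really shrinks: fewer output channels also means fewer input channels
# for the next layer, which is what actually cuts inference time
pruned_weight = weight[keep]   # (n_keep, in_c, kh, kw)
pruned_bias = bias[keep]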
The link to the code is given.
The channel pruning py file has been uploaded; see the link.