# 根据BN裁剪模型channel — prune model channels according to BN (batch-norm) scale values

import sys                                                                                                    
import caffe                                                                                                                                     
import numpy as np                                                       
                                                                         
# Fraction of all BN channels (ranked by |gamma|) that will be selected for pruning.
percent = 0.5
# Print numpy arrays in full instead of eliding the middle with "...".
np.set_printoptions(threshold=sys.maxsize)

MODEL_FILE = '********.prototxt'  # network definition file
PRETRAIN_FILE = '*********.caffemodel'  # trained weights file

params_txt = 'params.txt'  # file intended for dumping parameters
# NOTE(review): this handle is never written to or closed in the visible code;
# it is kept as-is because other code may rely on the name `pf`.
pf = open(params_txt, 'w')

total = 0  # number of BN channels found in the network

net = caffe.Net(MODEL_FILE, PRETRAIN_FILE, caffe.TEST)  # load the model and its weights
print("load model successful")
for param_name in net.params.keys():
    # Only look at the blobs of BN layers: other layers may own a single blob
    # (e.g. no bias), and indexing a second blob on them would raise IndexError.
    if 'bn' in param_name:
        total += net.params[param_name][0].data.shape[0]
print("the number of bn is :", total)
# Gather |gamma| of every BN channel and derive one global pruning threshold:
# channels whose |gamma| is below `thre` are the ones selected for pruning.
bn = []
index = 0
small_num = 0
for param_name in net.params.keys():
    if 'bn' in param_name:
        gamma = net.params[param_name][0].data  # blob 0 is treated as gamma
        # Flatten a copy instead of reshaping the parameter array in place.
        bn.extend(np.abs(gamma).ravel().tolist())

y = sorted(bn)  # ascending |gamma| over all BN channels
if not y:
    raise ValueError("no 'bn' layers found in the model; cannot compute a pruning threshold")

# The pruning ratio picks the position of the threshold in the sorted list;
# clamp it so that percent == 1.0 does not index past the end.
thre_index = min(int(total * percent), len(y) - 1)
thre = y[thre_index]

print("thread is :", thre)
# Nothing is physically removed here: gamma and beta of the channels selected
# for pruning are only set to 0, so the effect on the results can be tested.
# A layer may end up with every channel selected; in that case the channel
# with the largest |gamma| is left untouched.
pruned = 0
cfg = []
cfg_mask = []
for param_name in net.params.keys():
    if 'bn' not in param_name:
        continue

    # Blob 0 is treated as gamma (scale) and blob 1 as beta (shift).
    gamma = net.params[param_name][0].data
    size = gamma.shape[0]
    print(param_name)

    # NOTE(review): assumes one gamma value per channel (1-D blob) — confirm.
    magnitude = np.abs(gamma)
    prune_idx = np.flatnonzero(magnitude < thre)
    num = int(prune_idx.size)  # channels of this layer below the threshold
    small_num += num

    if size and num == size:
        # Every channel fell below the threshold: keep the one with the
        # largest |gamma| (first one on ties) so the layer is not wiped out.
        keep_id = int(np.argmax(magnitude))
        prune_idx = prune_idx[prune_idx != keep_id]

    net.params[param_name][0].data[prune_idx] = 0
    net.params[param_name][1].data[prune_idx] = 0

    print(size, num)
print(small_num)

net.save('at_least_one_new.caffemodel')

 

# 你可能感兴趣的:(深度学习)