caffe：修改 prototxt，将 BN 层转换为 BatchNorm + Scale 层

#coding=UTF-8
import sys
sys.path.insert(0,'/home/cdli/ECO2/caffe_3d/python')
import copy
from caffe.proto import caffe_pb2
from google.protobuf import text_format
import google

def create_layer(base_name_, type_, bottom_, top_):
	"""Build one of the two Caffe layers that replace a legacy 'BN' layer.

	Args:
		base_name_: name of the original BN layer; the new layer is named
			base_name_ + '/' + type_.
		type_: 'batchnorm' or 'scale' (case-insensitive). Any other value
			yields a bare layer with only its name and type set.
		bottom_: input blob name of the original BN layer.
		top_: output blob name of the original BN layer.

	Returns:
		A caffe_pb2.LayerParameter. The BatchNorm layer maps bottom_ -> top_
		and the Scale layer runs in place on top_, so BatchNorm must be placed
		before Scale in the net.
	"""
	layer = caffe_pb2.LayerParameter()
	layer.name = base_name_ + '/' + type_
	kind = type_.lower()
	if kind == 'batchnorm':
		# Caffe resolves layer types case-sensitively; the registered name is
		# 'BatchNorm', so a lowercase type would fail at net construction.
		layer.type = 'BatchNorm'
		layer.bottom.append(bottom_)
		layer.top.append(top_)
		# Three ParamSpecs, one per BatchNorm blob; lr/decay 0 keeps the
		# solver from updating the stored statistics.
		for _ in range(3):
			layer.param.add(lr_mult=0, decay_mult=0)
		# NOTE(review): False forces mini-batch statistics even in a
		# deploy/test net (Caffe's default there is True) -- confirm intent.
		layer.batch_norm_param.use_global_stats = False
		layer.batch_norm_param.eps = 0.00001
	elif kind == 'scale':
		layer.type = 'Scale'
		# In place on the BatchNorm output.
		layer.bottom.append(top_)
		layer.top.append(top_)
		# Two ParamSpecs: scale (gamma) and bias (beta).
		for _ in range(2):
			layer.param.add(lr_mult=0.2, decay_mult=0.2)
		layer.scale_param.filler.value = 1
		# bias_term defaults to false. Without it the layer has a single blob:
		# Caffe rejects the second ParamSpec ("Too many params specified")
		# and the BN shift term would be lost.
		layer.scale_param.bias_term = True
		layer.scale_param.bias_filler.value = 0
	else:
		layer.type = type_
	return layer

# Script: replace every legacy 'BN' layer in a prototxt with a BatchNorm layer
# followed by an in-place Scale layer, then write the result to
# <input-without-extension>-batchnorm.prototxt.
# The input path may be passed as the first command-line argument; it defaults
# to 'deploy-pool3.prototxt'.
import os

net1 = caffe_pb2.NetParameter()
deploy = sys.argv[1] if len(sys.argv) > 1 else 'deploy-pool3.prototxt'
with open(deploy) as f:
	# Parse the prototxt text into net1.
	text_format.Merge(f.read(), net1)

# Snapshot the parsed net, then rebuild net1.layer from that snapshot.
# Popping/inserting into the repeated field while enumerating it shifts the
# indices under the loop.
source = caffe_pb2.NetParameter()
source.CopyFrom(net1)
net1.ClearField('layer')
for l in source.layer:
	if str(l.type) != 'BN':
		net1.layer.add().CopyFrom(l)
		continue
	name = str(l.name)
	print(name)
	bottom = str(l.bottom[0])
	top = str(l.top[0])
	batchnorm = create_layer(name, 'batchnorm', bottom, top)
	scale = create_layer(name, 'scale', bottom, top)
	# BatchNorm produces `top` and Scale consumes it in place, so BatchNorm
	# must come first.
	net1.layer.add().CopyFrom(batchnorm)
	net1.layer.add().CopyFrom(scale)

# splitext strips only the final extension, so dots elsewhere in the path
# (e.g. './models/x.prototxt') are preserved.
with open(os.path.splitext(deploy)[0] + '-batchnorm.prototxt', 'w') as f:
	f.write(str(net1))


你可能感兴趣的:(caffe)