[tensorflow] How to obtain parameter information from a TensorFlow .pb file

Problem setting: to compare model complexity with the state of the art, I need the number of model parameters. However, the provided model is not a checkpoint (ckpt) file but a .pb file, and after importing its graph, tf.trainable_variables() returns an empty list.

This answer gives the solution.

For example, in my case:

```python
# TF 1.x API (tf.GraphDef, tf.Session); under TF 2.x use tf.compat.v1 with eager execution disabled
import tensorflow as tf


def print_params(constant_values):
    """Print and sum the element counts of Const nodes that look like parameters."""
    total = 0
    prompt = []
    forbidden = ['shape', 'stack']
    for k, v in constant_values.items():
        # filtering some by checking ndim and name
        if v.ndim < 1:  # scalars are never weights
            continue
        if v.ndim == 1:
            # 1-D helper constants (reshape targets, strided_slice stacks) are dropped by name
            token = k.split('/')[-1]
            if any(word in token for word in forbidden):
                continue

        shape = v.shape
        cnt = int(v.size)  # product of all dimensions
        prompt.append('{} with shape {} has {}'.format(k, shape, cnt))
        print(prompt[-1])
        total += cnt
    prompt.append('totaling {}'.format(total))
    print(prompt[-1])
    return prompt


# import graph
with open('spmc_120_160_4x3f.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    # frames_lr: the input tensor, built elsewhere in the original script
    output = tf.import_graph_def(graph_def, input_map={'Placeholder:0': frames_lr}, return_elements=['output:0'])
    output = output[0]
# ... other codes

# obtain variables: in the imported graph they are stored as Const ops
constant_values = {}
with tf.Session() as sess:
    constant_ops = [op for op in sess.graph.get_operations() if op.type == 'Const']
    for constant_op in constant_ops:
        constant_values[constant_op.name] = sess.run(constant_op.outputs[0])

# printing variables (print_params must be defined before this call)
print_params(constant_values)
```
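
print_params fetches every constant's value with sess.run only to read its shape. A lighter variant, sketched below under the assumption that the same ndim/name heuristic is what you want, works on a plain mapping from node name to shape. In TF 1.x such a mapping can be collected without running the session, e.g. {op.name: op.outputs[0].shape.as_list() for op in graph.get_operations() if op.type == 'Const'}. The function count_params_from_shapes and the node names in the usage example are hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple

# Same blacklist as print_params
FORBIDDEN = ('shape', 'stack')


def count_params_from_shapes(shapes: Dict[str, Sequence[int]],
                             forbidden: Sequence[str] = FORBIDDEN) -> Tuple[int, List[str]]:
    """Apply print_params' ndim/name filter to shapes only and return (total, report lines)."""
    total = 0
    lines = []
    for name, shape in shapes.items():
        if len(shape) < 1:  # scalar constant: skip
            continue
        token = name.split('/')[-1]
        if len(shape) == 1 and any(word in token for word in forbidden):
            continue  # 1-D helper constant such as a reshape target
        cnt = math.prod(shape)  # number of elements (Python 3.8+)
        lines.append('{} with shape {} has {}'.format(name, tuple(shape), cnt))
        total += cnt
    lines.append('totaling {}'.format(total))
    return total, lines


# Hypothetical node names in the style produced by tf.import_graph_def (default prefix "import/")
shapes = {
    'import/conv1/weights': [3, 3, 3, 64],
    'import/conv1/biases': [64],
    'import/Reshape/shape': [4],
    'import/strided_slice/stack_1': [1],
    'import/mul/y': [],
}
total, lines = count_params_from_shapes(shapes)
print('\n'.join(lines))
```

Here only the weights (3*3*3*64 = 1728) and the biases (64) survive the filter, so the last printed line is "totaling 1792".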

As mentioned in the answer, all imported nodes are constants: a .pb file like this one typically stores a frozen graph, in which every variable has been converted into a Const node. The imported Const ops therefore mix genuine constants (reshape targets, slice indices, scalars) with the former trainable variables, and only the latter are needed. So I have to filter out the useless ones manually by checking their ndim and name.
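
The name blacklist in print_params has to be extended whenever a new kind of helper constant shows up. An alternative is a whitelist: freezing a graph (e.g. with convert_variables_to_constants) keeps each variable's name for its Const node, so one can keep only nodes whose last name component is a typical parameter name. The suffix list below (weights/biases in slim style, kernel/bias in tf.layers style, gamma/beta for batch norm) is an assumption to check against your model's actual node names, and count_trainable_like is a hypothetical helper.

```python
import math
from typing import Dict, Sequence

# Assumed names of trainable variables; adjust to the model's naming scheme.
PARAM_SUFFIXES = ('weights', 'biases', 'kernel', 'bias', 'gamma', 'beta')


def count_trainable_like(shapes: Dict[str, Sequence[int]],
                         suffixes: Sequence[str] = PARAM_SUFFIXES) -> int:
    """Sum the sizes of Const nodes whose last name component is a known parameter name."""
    total = 0
    for name, shape in shapes.items():
        if name.split('/')[-1] in suffixes:  # exact match, not substring
            total += math.prod(shape)
    return total


# Hypothetical node names
shapes = {
    'import/conv1/kernel': [3, 3, 3, 64],
    'import/conv1/bias': [64],
    'import/bn1/gamma': [64],
    'import/bn1/moving_mean': [64],  # batch-norm statistic, not trainable: not counted
    'import/Reshape/shape': [4],
}
print(count_trainable_like(shapes))
```

Exact matching avoids the substring pitfall of the blacklist, at the cost of silently missing parameters with non-standard names, so comparing both totals is a cheap sanity check.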
