[TensorFlow] Counting the parameters of a model: how do I calculate the number of parameters in my model?

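import logging
import tensorflow as tf  # TensorFlow 1.x API (tf.trainable_variables)
import config  # the author's own settings module (not shown); config.logger is the log file path

logging.basicConfig(level=logging.INFO, format='%(message)s', filemode='w', filename=config.logger)

def _params_usage():
	"""Log the parameter count of every trainable variable and return the report as a string."""
	total = 0
	prompt = []
	for v in tf.trainable_variables():
		shape = v.get_shape()  # a TensorShape
		cnt = 1
		for dim in shape:
			cnt *= dim.value  # Dimension.value is the raw int (None if the dimension is unknown)
		prompt.append('{} with shape {} has {}'.format(v.name, shape, cnt))
		logging.info(prompt[-1])
		total += cnt
	prompt.append('totaling {}'.format(total))
	logging.info(prompt[-1])
	return '\n'.join(prompt)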

shape is a TensorShape. It is iterable, and in TensorFlow 1.x each element is a Dimension whose .value attribute gives the size of that dimension as a plain integer (or None if the size is unknown).
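The count for each variable is simply the product of its dimensions (a scalar, with shape [], contributes 1), and the total is the sum over all variables. The plain-Python sketch below reproduces that arithmetic without TensorFlow; the helper count_params and the shapes for a 784-to-128 dense layer are hypothetical examples. It also rejects unknown (None) dimensions, which in the loop above would make cnt *= dim.value fail with a TypeError.

```python
from functools import reduce
from operator import mul

def count_params(shapes):
    """Sum, over all variables, the product of each variable's dimensions.

    `shapes` maps a variable name to a list of ints, standing in for what
    v.get_shape() reports for a fully defined variable (hypothetical input).
    """
    total = 0
    for name, shape in shapes.items():
        if any(d is None for d in shape):
            raise ValueError(f"{name} has an unknown dimension: {shape}")
        cnt = reduce(mul, shape, 1)  # empty shape (scalar) gives 1
        total += cnt
    return total

# Hypothetical dense layer: 784 inputs -> 128 units
shapes = {"dense/kernel:0": [784, 128], "dense/bias:0": [128]}
print(count_params(shapes))  # 784*128 + 128 = 100480
```

Within TensorFlow itself, TensorShape.num_elements() returns the same per-variable product, or None when the shape is not fully defined.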

The function _params_usage() above writes each line to the log configured by logging.basicConfig (here, the file config.logger) and also returns the same lines joined into one string. The caller can print that string to stdout, so the report appears both in the log file and on the console.
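If you would rather have every log record reach both the file and stdout without printing the returned string by hand, the standard logging module can attach two handlers to one logger. This is a minimal sketch; the helper setup_logger, the logger name "params", and the file name params.log are hypothetical.

```python
import logging
import sys

def setup_logger(log_path):
    """Return a logger that writes each record to log_path and to stdout."""
    logger = logging.getLogger("params")  # hypothetical logger name
    logger.setLevel(logging.INFO)
    # Remove and close old handlers so repeated calls don't duplicate output.
    for h in list(logger.handlers):
        logger.removeHandler(h)
        h.close()
    fmt = logging.Formatter("%(message)s")  # same format as the basicConfig call
    for handler in (logging.FileHandler(log_path, mode="w"),
                    logging.StreamHandler(sys.stdout)):
        handler.setFormatter(fmt)
        logger.addHandler(handler)
    logger.propagate = False  # don't also pass records to the root logger
    return logger

if __name__ == "__main__":
    log = setup_logger("params.log")
    log.info("totaling %d", 100480)  # printed to stdout and written to params.log
```

Because both handlers share one formatter, the file and the console show identical lines. On Python 3.3+, logging.basicConfig(handlers=[...]) can configure the root logger the same way.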
