Tensorflow计算一个模型的浮点运算数

1、统计模型的浮点运算数和参数量

  • FLOPS:注意全大写,是floating point operations per second的缩写,意指每秒浮点运算次数,理解为计算速度。是一个衡量硬件性能的指标。
  • FLOPs:注意s小写,是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。
  • MACCs:是multiply-accumulate operations),也叫MAdds,意指乘-加操作,理解为计算量,MAdds 大约是 FLOPs 的一半。

    当我们使用tensorflow设计好一个神经网络模型之后,我们如何统计这个模型有多少浮点运算数(FLOPs)和多少参数量呢?我们可以使用tensorflow中的一个模块来进行统计,实例如下(以统计VGG为例):

#coding = utf-8

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import vgg

def stats_graph(graph):
	flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
	params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
	print('FLOPs: {};    Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))


def main():
	with tf.Graph().as_default() as graph:
		inputs = tf.placeholder(dtype = tf.float32, shape = [1, 224, 224, 3])
		with slim.arg_scope(vgg.vgg_arg_scope()):
			_, end_points = vgg.vgg_16(inputs,
			                        num_classes=1000,
			                        is_training=True,
			                        dropout_keep_prob=0.5,
			                        spatial_squeeze=False,
			                        scope='vgg_16')		
		stats_graph(graph)

if __name__ == '__main__':
	main()

    统计的输出为:

FLOPs: 31651968442;    Trainable params: 138357544

   计算方式参考论文:《Pruning Convolutional Neural Networks for Resource Efficient Inference》中的内容:

  Tensorflow计算一个模型的浮点运算数_第1张图片 

2、常用模型的FLOPs统计 (参考)

      Tensorflow计算一个模型的浮点运算数_第2张图片

  参考:https://blog.csdn.net/leayc/article/details/81001801    

             https://robertlexis.github.io/2018/08/28/Tensorflow-%E6%A8%A1%E5%9E%8B%E6%B5%AE%E7%82%B9%E6%95%B0%E8%AE%A1%E7%AE%97%E9%87%8F%E5%92%8C%E5%8F%82%E6%95%B0%E9%87%8F%E7%BB%9F%E8%AE%A1/

你可能感兴趣的:(TensorFlow)