tensorflow计算参数的数量以及FLOPs的估算

1 参数量的计算
该函数需要在训练的函数中调用即可执行,可以得出该网络执行的总参数。

def count():
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    return total_parameters

2 FLOPs的计算

def count_flops(graph):
    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
	print('FLOPs: {}'.format(flops.total_float_ops))

这两个计算参数和FLOPs的方法都必须要提前导入import tensorflow as tf头文件,才可以使用。
其中某些部分是参考:tensorflow参数及FLOPs估算

你可能感兴趣的:(程序)