Tensorflow计算网络参数量和计算量FLOPs

Author: 杭州电子科技大学-自动化学院-智能系统和机器人研究中心-Jolen Xie

先引入头文件

import tensorflow as tf

1.计算参数量

def count_param():       # 计算网络参数量
    total_parameters = 0
    for v in tf.trainable_variables():
        shape = v.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print('网络总参数量:', total_parameters)

2. 计算FLOPS

def count_flops(graph):
    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
    print('FLOPs: {}'.format(flops.total_float_ops))

3.一起算

def stats_graph(graph):
    flops = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.float_operation())
    params = tf.profiler.profile(graph, options=tf.profiler.ProfileOptionBuilder.trainable_variables_parameter())
    print('FLOPs: {};    Trainable params: {}'.format(flops.total_float_ops, params.total_parameters))

一起算的使用方法:

...
graph =tf.get_default_graph()
stats_graph(graph)
...

你可能感兴趣的:(tensorflow)