tensorflow统计graph中的trainable_variables

最简单的做法: 转自: https://blog.csdn.net/feynman233/article/details/79187304, 版权归原作者所有。

print(np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]))

另有篇博客讲解的很详细:原文地址https://blog.csdn.net/shwan_ma/article/details/78879620,版权归原作者所有。

原博主写的很好,将常用的方法记载下来供以后学习参考。

sess.run(tf.global_varibales_initializer())

variable_name = [v.name for v in tf.trainable_variables()]

print(variable_names)

 

variable_names = [v.name for v in tf.trainable_variables()]

values = sess.run(variable_names)

for k,v in zip(variable_names, values):

    print("Variable: ", k)

    print("Shape: ", v.shape)

    print(v)

 

for variable in tf.trainable_variables():

    shape = variable.get_shape()

    variable_parameters = 1

    for dim in shape:

        variable_parameters *= dim.value

    total_parameters += variable_parameters

 

你可能感兴趣的:(tf)