trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)解析

在写Inception V3代码的时候,遇到这一句代码,分享一下它的工作原理

代码:trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)

1. lambda是一个匿名函数,它的作用举例说明

a = lambda x:x*x
print(a(2))

输出为4,等价于函数

def a(x):
    return x*x
print(a(2))

那么这一个函数trunc_normal就是返回 tf.truncated_normal_initializer(0.0, stddev)的值,最后产生一个平均值为0.0,标准差为stddev的截断的正太分布。具体使用这个函数的时候调用tensorflow的tf.contrib.slim就很方便啦

import tensorflow as tf

slim = tf.contrib.slim
trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
weights = slim.variable('weights',
                             shape=[3 , 3], #形状
                             #参数初始化
                             initializer=trunc_normal(0.1),
                             )
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(sess.run(weights))

结果如下:

[[ 0.11840882  0.04289966 -0.02131811]
 [ 0.06113978 -0.03785787 -0.00641177]
 [ 0.08828283 -0.01430409  0.02136735]]

你可能感兴趣的:(tensorflow)