tf.truncated_normal 理解

truncated_normal(
    shape,
    mean=0.0,
    stddev=1.0,
    dtype=tf.float32,
    seed=None,
    name=None
)

功能说明:

产生截断正态分布随机数,取值范围为 [ mean - 2 * stddev, mean + 2 * stddev ]

参数列表:

参数名 必选 类型 说明
shape 1 维整形张量或 array 输出张量的维度
mean 0 维张量或数值 均值
stddev 0 维张量或数值 标准差
dtype dtype 输出类型
seed 数值 随机种子,若 seed 赋值,每次产生相同随机数
name string 运算名称


 

import tensorflow as tf
import matplotlib.pyplot as plt

tn = tf.truncated_normal([20],mean=5,stddev=1)

sess = tf.Session()
ov = sess.run(tn)
print(ov)
plt.plot(ov)
plt.show()

[3.183846  6.06047   4.565305  5.194239  5.8779397 6.048414  5.664632
 5.2293634 4.4598646 5.3759885 4.3167524 5.5291214 4.037938  5.0970454
 4.8114433 4.9737196 4.246079  5.5007358 5.0794163 4.2945256]

tf.truncated_normal 理解_第1张图片

你可能感兴趣的:(tensoflow)