会话级种子:seed
当在代码中使用了随机数,但是希望代码在不同时间或者不同的机器上运行能够得到相同的随机数,以至于能够得到相同的结果,那么久需要到设置随机函数的seed 参数,对应的变量可以跨session生成 相同的随机数:
例子:
tf.reset_default_graph()#函数用于清除默认图形堆栈并重置全局默认图形
a= tf.random_normal([1],mean=0, stddev=1, seed=1)
b= tf.random_normal([1],mean=0,stddev=1)
print('Session1')
with tf.Session() as sess1:
print(sess1.run(a))
print(sess1.run(a))
print(sess1.run(b))
print(sess1.run(b))
print('Session2')
with tf.Session() as sess2:
print(sess2.run(a))
print(sess2.run(a))
print(sess2.run(b))
print(sess2.run(b))
结果:
Session1
[-0.8113182]
[0.6396971]
[1.1263528]
[1.546696]
Session2
[-0.8113182]
[0.6396971]
[-0.5055166]
[-0.54076374]
可以看出设置了a设置了seed=1之后,在不同的Session中a产生的随机数是一致的,而b在不同的Session中产生的随机数是不一致的。
图级种子:tf.set_random_seed
如果不想一个一个的设置随机种子seed,那么可以使用全局设置tf.set_random_seed()函数,使用之后后面设置的随机数都不需要设置seed,而可以跨会话生成相同的随机数。
例子:
tf.reset_default_graph()#函数用于清除默认图形堆栈并重置全局默认图形
tf.set_random_seed(1)#设置全局随机种子
a= tf.random_normal([1],mean=0, stddev=1)
b= tf.random_normal([1],mean=0,stddev=1)
print('Session1')
with tf.Session() as sess1:
print(sess1.run(a))
print(sess1.run(a))
print(sess1.run(b))
print(sess1.run(b))
print('Session2')
with tf.Session() as sess2:
print(sess2.run(a))
print(sess2.run(a))
print(sess2.run(b))
print(sess2.run(b))
结果:
Session1
[-0.67086124]
[0.9259123]
[-0.3476087]
[-0.03807747]
Session2
[-0.67086124]
[0.9259123]
[-0.3476087]
[-0.03807747]
上面例子我们也发现了,即使设置了随机种子,但是在同一个会话当中,产生的随机数也会不一致,那么如何解决呢?
情况一:定义两个变量的随机生成函数一样,种子一样,结果一样
例子:
tf.reset_default_graph()#函数用于清除默认图形堆栈并重置全局默认图形
a= tf.random_normal([1],mean=0, stddev=1,seed=2)
b= tf.random_normal([1],mean=0,stddev=1,seed=2)
print('Session1')
with tf.Session() as sess1:
print('a')
print(sess1.run(a))
print(sess1.run(a))
print('b')
print(sess1.run(b))
print(sess1.run(b))
结果:
Session1
a
[-0.85811085]
[-0.20793143]
b
[-0.85811085]
[-0.20793143]
情况二:设置为变量variable,得到同一个session可复用的结果:
tf.reset_default_graph()#函数用于清除默认图形堆栈并重置全局默认图形
a= tf.Variable(tf.random_normal([1],mean=0, stddev=1,seed=2))
init_op=tf.global_variables_initializer()
print('Session1')
with tf.Session() as sess1:
sess1.run(init_op)
print('a')
print(sess1.run(a))
print(sess1.run(a))
结果:
Session1
a
[-0.85811085]
[-0.85811085]