tensorflow中sess.run()越来越慢的问题解决

tensorflow中sess.run()越来越慢的问题解决

在我们运行tf.Session.run()的次数越多,会发现程序的输出越来越慢,这是因为直接用run去读取数据是很慢的,所以run越多,就越多的数据被缓存,导致速度越来越慢。
先上一个运行很慢的例子:
#上很慢的例子

import matplotlib.pyplot as plt
import tensorflow as tf
y = []
z = []
N = 200
#global_step = tf.Variable(0, name='global_step', trainable=False)
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for global_step in range(N):
        # cycle=False
        print('step:',global_step)
        learing_rate1 = tf.train.polynomial_decay(
            learning_rate=0.1, global_step=global_step, decay_steps=50,
            end_learning_rate=0.01, power=0.5, cycle=False)
        # cycle=True
        learing_rate2 = tf.train.polynomial_decay(
            learning_rate=0.1, global_step=global_step, decay_steps=50,
            end_learning_rate=0.01, power=0.5, cycle=True)  #这里是直接run,没有定义传递数值的变量
        lr1 = sess.run([learing_rate1])
        lr2 = sess.run([learing_rate2])
        y.append(lr1[0])
        z.append(lr2[0])
 
x = range(N)
fig = plt.figure()
ax = fig.add_subplot(111)
plt.plot(x, z, 'g-', linewidth=2)
plt.plot(x, y, 'r--', linewidth=2)
plt.title('polynomial_decay')
ax.set_xlabel('step')
ax.set_ylabel('learing rate')
plt.show()

你会发现,越到后面,运行的越慢。
下面上快速的代码,只是定义了了一个变量,问题就解决了。

import matplotlib.pyplot as plt
import tensorflow as tf
y = []
z = []
N = 5000
#global_step = tf.Variable(0, name='global_step', trainable=False)
global_step = tf.Variable(tf.constant(0),name='test')  #这里定义了一个变量,并且在循环之外,run的时候只需要feed进取就行了,不同点就在这里
learing_rate1 = tf.train.cosine_decay_restarts(
            learning_rate=0.1, global_step=global_step,t_mul=2.0,m_mul=0.7, alpha=0.0, first_decay_steps=200)
learing_rate2 = tf.train.cosine_decay_restarts(
            learning_rate=0.1, global_step=global_step, t_mul=2.0,m_mul=1.0, alpha=0.0,first_decay_steps=200)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for global_step1 in range(N):
        lr = sess.run([learing_rate1,learing_rate2],feed_dict={global_step:global_step1})
        # lr2 = learing_rate2.run()
        #print(lr)
        y.append(lr[0])
        z.append(lr[1])

x = range(N)
fig = plt.figure()
ax = fig.add_subplot(111)
plt.plot(x, y, 'r-', linewidth=2)
#plt.plot(x, z, 'g-', linewidth=2)
plt.title('cosine_decay_restarts')
ax.set_xlabel('step')
ax.set_ylabel('learing rate')
plt.show()

总结:不要直接把要run的内容放在循环体内,应该放在循环体外,在循环之外定义变量,最后将变量feed进去,就能够避免越来越慢。

你可能感兴趣的:(tensorflow中sess.run()越来越慢的问题解决)