[tf] Alternating between train and test

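```python
# TensorFlow 1.x API; on TF 2.x use tf.compat.v1 with eager execution disabled
import tensorflow as tf

tf.reset_default_graph()
# Two stand-in "datasets"
train_set = tf.Variable([1, 2, 3, 4], name='train')
test_set = tf.Variable([5, 6, 7, 8], name='test')
# Boolean switch fed at run time: True selects the train branch, False the test branch
is_train = tf.placeholder(dtype=tf.bool)
# The ops are built inside the lambdas, so only the selected branch runs
c = tf.cond(is_train, lambda: train_set * 2, lambda: test_set * 2)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(20):
    # Every step runs the train branch
    print('run train:', sess.run(c, feed_dict={is_train: True}))
    # At i = 0 and i = 10, also run the test branch
    if i % 10 == 0:
        print('run test:', sess.run(c, feed_dict={is_train: False}))
```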
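To make the interleaving explicit, here is a stdlib-only sketch that reproduces the order of runs the loop above performs (the function name `alternation_schedule` is hypothetical, not part of TensorFlow). With 20 steps and a test run whenever `i % 10 == 0`, the test branch runs at steps 0 and 10, each time right after that step's training run.

```python
def alternation_schedule(num_steps, eval_every):
    """Return the (kind, step) runs performed by the train/test loop.

    Mirrors the loop above: one training run on every step, plus a test run
    on every step where step % eval_every == 0 (so step 0 is included).
    """
    runs = []
    for step in range(num_steps):
        runs.append(("train", step))
        if step % eval_every == 0:
            runs.append(("test", step))
    return runs


schedule = alternation_schedule(20, 10)
print([step for kind, step in schedule if kind == "test"])  # [0, 10]
print(sum(1 for kind, _ in schedule if kind == "train"))     # 20
```

Because step 0 passes the `i % 10 == 0` check, the first test run happens before any training step has finished; a common variant is `(i + 1) % 10 == 0`, which tests at steps 9 and 19 instead.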
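In a real model the same switch chooses which input the network is built on (`MODEL`, `train...` and `test...` stand for your own model function and inputs):

```python
logit = tf.cond(is_train, lambda: MODEL(train...), lambda: MODEL(test...))
```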
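For the `MODEL` pattern to be useful, both branches must use the same weights; if `MODEL` created fresh variables on each call, the test branch would get its own untrained copy. In TF 1.x one common way to avoid this is to build `MODEL` with `tf.get_variable` inside `tf.variable_scope(..., reuse=tf.AUTO_REUSE)`. The sketch below is only a hypothetical pure-Python analogue of that idea (the `Model` class and `cond` helper are illustrative, not TensorFlow APIs): one parameter is created once and shared by both branches, so an update made during "training" is visible when the test branch runs.

```python
class Model:
    """Hypothetical stand-in for MODEL: one parameter shared by both branches."""

    def __init__(self, w):
        self.w = w  # created once, reused by the train and the test branch

    def __call__(self, xs):
        return [self.w * x for x in xs]


def cond(pred, true_fn, false_fn):
    # Analogue of tf.cond at run time: only the selected branch is evaluated.
    return true_fn() if pred else false_fn()


train_set = [1, 2, 3, 4]
test_set = [5, 6, 7, 8]
model = Model(w=2)

print(cond(True, lambda: model(train_set), lambda: model(test_set)))   # [2, 4, 6, 8]
model.w = 3  # pretend a training step updated the shared weight
print(cond(False, lambda: model(train_set), lambda: model(test_set)))  # [15, 18, 21, 24]
```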
