TensorFlow: editing each value of a tensor with tf.where

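import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

# Condition: a plain Python list of bools, or a list of scalar bool tensors
judge_list1 = [True, True, False, False]
judge_list2 = [tf.constant(True), tf.constant(True), tf.constant(False), tf.constant(False)]

# Values used where the condition is True: a plain list, or a list of scalar tensors
input_tensor1 = [1, 2, 3, 4]
input_tensor2 = [tf.constant(1), tf.constant(2), tf.constant(3), tf.constant(4)]

# tf.where(condition, x, y) picks x[i] where condition[i] is True, otherwise y[i]
result1 = tf.where(judge_list1,
                   input_tensor1,
                   [100, 200, 300, 400])

# Lists of scalar tensors are packed into 1-D tensors automatically
result2 = tf.where(judge_list2,
                   input_tensor2,
                   [100, 200, 300, 400])

# The graph contains no variables, so no initializer is needed
with tf.Session() as sess:
    print(sess.run(result1))
    print(sess.run(result2))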
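The code above uses the TensorFlow 1.x graph API (tf.Session), which no longer exists in the top-level namespace of TensorFlow 2.x. A minimal sketch of the same computation in TF 2.x eager mode, assuming TensorFlow 2.x is installed; the variable names judge_list, input_tensor and replacement are illustrative:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)

judge_list = [True, True, False, False]
input_tensor = [1, 2, 3, 4]
replacement = [100, 200, 300, 400]

# Plain Python lists are converted to tensors; the result is computed immediately.
result1 = tf.where(judge_list, input_tensor, replacement)

# Lists of scalar tensors are packed into 1-D tensors, just as in graph mode.
judge_list2 = [tf.constant(b) for b in judge_list]
input_tensor2 = [tf.constant(v) for v in input_tensor]
result2 = tf.where(judge_list2, input_tensor2, replacement)

# .numpy() turns the eager tensor into a NumPy array for printing.
print(result1.numpy())
print(result2.numpy())
```

Both lines should print the same values as the session-based version. To keep the original script unchanged under TF 2.x instead, it can be run with `import tensorflow.compat.v1 as tf` followed by `tf.disable_v2_behavior()`.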

Printed output:
[  1   2 300 400]
[  1   2 300 400]
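In practice the condition is usually computed from the tensor itself, which is how tf.where edits each value element by element. A minimal sketch, assuming TensorFlow 2.x (whose tf.where broadcasts its arguments); the threshold 2 and the replacement value 0 are arbitrary choices for illustration:

```python
import tensorflow as tf  # assumes TensorFlow 2.x

x = tf.constant([1, 2, 3, 4])

# Element-wise boolean mask: True where the original value should be kept.
keep = x <= 2

# TF 2.x tf.where broadcasts, so a scalar replacement value is allowed.
edited = tf.where(keep, x, 0)

# Equivalent form with an explicit replacement tensor of the same shape as x.
edited_explicit = tf.where(keep, x, tf.zeros_like(x))

print(edited.numpy())
```

The explicit tf.zeros_like form also works with tf.compat.v1.where, which requires x and y to have matching shapes.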
