TensorFlow：用 tf.while_loop 逐个编辑 shape 为 None 的 tensor 中的每个值

import tensorflow as tf


def func1():
    """Return a scalar boolean constant tensor holding False.

    Passed as the true branch of the tf.cond in body().
    """
    false_flag = tf.constant(False)
    return false_flag


def func2():
    """Return a scalar boolean constant tensor holding True.

    Passed as the false branch of the tf.cond in body().
    """
    true_flag = tf.constant(True)
    return true_flag

# Number of rows in the demo input. The loop condition reads the row count
# dynamically via tf.shape, so the loop itself does not depend on this value.
batch_size = 4

# Demo input: a (batch_size, 2) float32 tensor of ones. tf.ones can produce
# float32 directly, so no separate tf.cast is needed.
global_tensor = tf.ones([batch_size, 2], dtype=tf.float32)

# Loop accumulator, seeded with one placeholder element (False) because the
# loop body grows it with tf.concat. The seed is dropped later with [1:].
input_list = tf.constant(False, shape=[1])

# Loop counter: index of the row of global_tensor currently being processed.
input_index = tf.constant(0)


def while_cond(i1, i2):
    """Continue the loop while row index i1 is below global_tensor's row count.

    i2 (the accumulated list) is part of the tf.while_loop signature but is
    not used here. The row count is read at run time with tf.shape, so this
    also works when the batch dimension is actually None.
    """
    row_count = tf.shape(global_tensor)[0]
    return tf.math.less(i1, row_count)


def body(i, input_list):
    """Loop body: classify row i of global_tensor and append the result.

    A row whose elements all equal 1.0 yields False (via func1); any other
    row yields True (via func2). The flag is appended to input_list, so the
    accumulator grows by one element per iteration (hence the [None] shape
    invariant given to tf.while_loop).

    Args:
        i: scalar int32 loop counter (index of the current row).
        input_list: 1-D bool tensor accumulated so far.

    Returns:
        A tuple (i + 1, input_list with the new flag appended).
    """
    one_vector = global_tensor[i]
    # Compare against a ones tensor of the row's own shape. This is equivalent
    # to "count of matches == 2" for width-2 rows, without hard-coding the width.
    all_ones = tf.reduce_all(tf.math.equal(one_vector, tf.ones_like(one_vector)))
    res = tf.cond(all_ones, func1, func2)
    # res is already a scalar bool; wrapping it in a list makes it rank-1 for concat.
    input_list = tf.concat([input_list, [res]], 0)
    return tf.add(i, 1), input_list


# Run the loop: one iteration per row of global_tensor. The accumulator's
# length changes each iteration, so its shape invariant must be [None].
result_index, result_list = tf.while_loop(while_cond, body, loop_vars=[input_index, input_list],
                                          shape_invariants=[input_index.get_shape(), tf.TensorShape([None])])

tensor_to_edit = tf.constant([1, 2, 3, 4])

# result_list[1:] drops the placeholder seed element, so the mask lines up
# with the rows. Where the mask is True, keep tensor_to_edit's value;
# otherwise take the replacement value.
# NOTE(review): tensor_to_edit and the replacement list are hard-coded to
# length 4 and must match batch_size.
tmp = tf.where(result_list[1:],
               tensor_to_edit,
               [100, 200, 300, 400])

init = tf.global_variables_initializer()
# Graph-mode session (TF 1.x API). Using it as a context manager closes it
# when the demo finishes instead of leaking it.
with tf.Session() as sess:
    sess.run(init)
    # Fetch both loop outputs in a single run instead of executing the loop twice.
    index_value, list_value = sess.run([result_index, result_list])
    print(index_value)
    print(list_value)
    print(sess.run(tmp))

运行后 print 的输出结果：
4
[False False False False False]
[100 200 300 400]

你可能感兴趣的:(TensorFlow,tensorflow,python,深度学习)