TensorFlow: checking whether a tensor contains a given sequence

If the length of the tensor is known in advance, this is fairly simple:

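import tensorflow as tf  # written for TensorFlow 1.x (graph mode with tf.Session)

one_vector = tf.constant([0, 111, 222, 333, 0])
target = tf.constant([111, 222, 333])

# Slide a window of length 3 over the 5-element vector: 5 - 3 + 1 = 3 start positions.
tmp_list = []
for tmp_index in range(0, 3):
    window = one_vector[tmp_index:tmp_index + 3]
    # Element-wise comparison with the target, cast to 0/1 so matches can be summed.
    tmp_list.append(tf.cast(tf.math.equal(window, target), tf.int32))

# Total number of matching elements across all windows.
total = tf.reduce_sum(tmp_list)

# No variables are created, so no initializer is needed.
sess = tf.Session()

print(sess.run(tmp_list))
print(sess.run(total))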

Printed output:
[array([0, 0, 0], dtype=int32), array([1, 1, 1], dtype=int32), array([0, 0, 0], dtype=int32)]
3
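One caveat about the total of 3 above: it counts matching elements across all windows, so partial matches in different windows can also add up to 3 without any window matching completely. Below is a minimal sketch of a stricter check, assuming both lengths are still known when the graph is built. The function name `contains_window` is hypothetical, not something from the original post, and the explicit `tf.Graph` plus `tf.compat.v1.Session` is an assumption chosen so the sketch should run on TensorFlow 1.14+ as well as 2.x.

```python
import tensorflow as tf


def contains_window(values, pattern):
    """Return True if `pattern` occurs as a contiguous run inside `values`.

    Both arguments are plain Python lists of ints, so their lengths are
    known while the graph is being built.
    """
    window = len(pattern)
    num_windows = len(values) - window + 1
    if num_windows < 1:
        # The pattern is longer than the vector, so it cannot be contained.
        return False

    graph = tf.Graph()
    with graph.as_default():
        vector = tf.constant(values, dtype=tf.int32)
        target = tf.constant(pattern, dtype=tf.int32)
        # One boolean per window: True only if every element matches.
        window_hits = [
            tf.reduce_all(tf.math.equal(vector[i:i + window], target))
            for i in range(num_windows)
        ]
        # True if at least one window matched completely.
        found = tf.reduce_any(tf.stack(window_hits))

    with tf.compat.v1.Session(graph=graph) as sess:
        return bool(sess.run(found))
```

With the vector from the post, `contains_window([0, 111, 222, 333, 0], [111, 222, 333])` returns True. For `[111, 0, 0, 222, 333]` the element-wise sum would also be 3 (one match in the first window, two in the last), but `contains_window` returns False because no single window matches in full.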
