Handling mixed data of different types with tf.estimator and tf.data

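# Targets TensorFlow 1.x graph mode (tf.Session).
import tensorflow as tf

# One type label per row of `data`.
data_type = tf.constant([1, 2, 1, 2])

# tf.where on a boolean tensor returns the indices of its True entries,
# here with shape [N, 1].
where_index1 = tf.where(tf.equal(data_type, 1))
where_index2 = tf.where(tf.equal(data_type, 2))

data = tf.constant([[10, 10], [20, 20], [30, 30], [40, 40]])

# Gather the rows of `data` that belong to each type.
data1 = tf.gather_nd(data, where_index1)
data2 = tf.gather_nd(data, where_index2)

with tf.Session() as sess:
    print(sess.run(data1))
    print(sess.run(data2))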

Printed output:
[[10 10]
 [30 30]]
[[20 20]
 [40 40]]
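
The same split can also be written with tf.boolean_mask, which takes the boolean condition directly and skips the intermediate index tensors. The following is only a minimal sketch, and it assumes TensorFlow 2.x with eager execution (the original snippet above uses the 1.x tf.Session API instead); the variable names simply mirror the ones used above.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution enabled

# One type label per row of `data`, as in the snippet above.
data_type = tf.constant([1, 2, 1, 2])
data = tf.constant([[10, 10], [20, 20], [30, 30], [40, 40]])

# tf.boolean_mask keeps the rows of `data` where the mask is True.
data1 = tf.boolean_mask(data, tf.equal(data_type, 1))
data2 = tf.boolean_mask(data, tf.equal(data_type, 2))

# In eager mode the values can be read without a session.
print(data1.numpy())
print(data2.numpy())
```

Under these assumptions, data1 and data2 should hold the same rows as the tf.where / tf.gather_nd version.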
