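Building a non-zero mask in TensorFlow

In BERT input ids, 0 is the padding id. For non-negative ids, tf.sign returns 1 for every real token and 0 for every padding position. Casting that result to float32 gives a mask you can multiply into other tensors.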

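import tensorflow as tf

# TF1-style graph code; on TensorFlow 2.x, tf.compat.v1 restores graph mode and Session
tf.compat.v1.disable_eager_execution()

# Token ids, where 0 marks padding
bert_input_ids = tf.constant([[1, 2, 0, 3], [1, 0, 2, 0]], dtype=tf.float32)
# tf.sign maps positive ids to 1 and 0 to 0 (ids are non-negative)
tmp = tf.sign(bert_input_ids)
# The cast is a no-op here (ids are already float32) but is needed when ids are int32
bert_mask = tf.cast(tmp, tf.float32)

with tf.compat.v1.Session() as sess:
    print(sess.run(bert_mask))
    print(sess.run(bert_input_ids * bert_mask))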

Printed output:

[[1. 1. 0. 1.]
 [1. 0. 1. 0.]]
[[1. 2. 0. 3.]
 [1. 0. 2. 0.]]
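On TensorFlow 2.x you can build the same mask eagerly, with no Session. The following is a minimal sketch. It assumes that 0 is the padding id and that the ids are int32, the usual dtype for BERT inputs. It uses tf.not_equal in place of tf.sign, and the features tensor is a hypothetical stand-in for per-token outputs of shape [batch, seq, hidden].

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Hypothetical batch of token ids; 0 is assumed to be the padding id
input_ids = tf.constant([[1, 2, 0, 3], [1, 0, 2, 0]], dtype=tf.int32)

# 1.0 wherever the id is non-zero, 0.0 at padding positions.
# Unlike tf.sign, this stays strictly 0/1 even if an id were negative.
mask = tf.cast(tf.not_equal(input_ids, 0), tf.float32)

# Integer copy of the mask, the form BERT-style models usually take as attention_mask
attention_mask = tf.cast(mask, tf.int32)

# Hypothetical per-token features [batch, seq, hidden]; zero out padded positions
features = tf.ones([2, 4, 3])
masked_features = features * mask[:, :, tf.newaxis]  # mask broadcast to [2, 4, 1]

# Number of real (non-padding) tokens in each sequence
lengths = tf.reduce_sum(attention_mask, axis=1)

print(mask.numpy())
print(lengths.numpy())
```

The comparison-based mask does not rely on the ids being non-negative, and keeping both a float and an int copy avoids dtype mismatches when the mask is multiplied with float tensors or passed to a model as an integer input.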
