tf.nn.embedding_lookup

https://blog.csdn.net/huahuazhu/article/details/77161668

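    # encoding=utf-8
    import tensorflow as tf  # assumes TensorFlow 2.x (eager execution, no Session needed)

    # Embedding table: 2 rows (ids 0 and 1), each a 5-dimensional vector.
    encode_embeddings = tf.constant([[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]])

    # A 4x3 batch of ids; each id selects one row of encode_embeddings.
    input_ids = tf.constant([[1, 1, 0], [1, 0, 1], [1, 0, 1], [0, 1, 1]])

    # results is a 4x3x5 tensor: the shape of input_ids plus the embedding size 5.
    results = tf.nn.embedding_lookup(encode_embeddings, input_ids)
    print(results.shape)  # (4, 3, 5)
    print(results.numpy())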
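With a single params tensor, tf.nn.embedding_lookup behaves like row indexing (tf.gather along axis 0): every id in input_ids is replaced by the matching row of the table, so the output shape is the ids shape followed by the embedding size. The pure-Python sketch below only illustrates that idea and is not TensorFlow's implementation; the helper names embedding_lookup and one_hot_matmul are hypothetical. It reproduces the 4x3x5 result above and checks it against the equivalent "one_hot(ids) @ table" view.

```python
# Illustrative sketch only (hypothetical helpers, not TensorFlow code):
# an embedding lookup replaces each id with the corresponding table row.

def embedding_lookup(table, ids):
    """Replace each integer id in a (possibly nested) list with table[id]."""
    if isinstance(ids, int):
        # Base case: a single id selects one row; reject out-of-range ids
        # (Python would otherwise silently wrap negative indices).
        if not 0 <= ids < len(table):
            raise IndexError(f"id {ids} out of range for {len(table)} rows")
        return list(table[ids])
    # Recurse so the output keeps the nesting (shape) of ids.
    return [embedding_lookup(table, x) for x in ids]


def one_hot_matmul(table, ids_row):
    """Equivalent view for a flat list of ids: one_hot(ids) @ table."""
    num_rows, dim = len(table), len(table[0])
    out = []
    for i in ids_row:
        one_hot = [1 if r == i else 0 for r in range(num_rows)]
        # Row vector times matrix: only the row where one_hot is 1 survives.
        out.append([sum(one_hot[r] * table[r][c] for r in range(num_rows))
                    for c in range(dim)])
    return out


encode_embeddings = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 0]]
input_ids = [[1, 1, 0], [1, 0, 1], [1, 0, 1], [0, 1, 1]]

results = embedding_lookup(encode_embeddings, input_ids)
print(len(results), len(results[0]), len(results[0][0]))  # 4 3 5
print(results[0])  # [[6, 7, 8, 9, 0], [6, 7, 8, 9, 0], [1, 2, 3, 4, 5]]

# Each row of results matches the one-hot matrix-product formulation.
assert all(one_hot_matmul(encode_embeddings, row) == res
           for row, res in zip(input_ids, results))
```

The one-hot view explains why an embedding layer is often described as a linear layer applied to one-hot inputs; the lookup simply skips multiplying by all the zeros.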
