tf.sequence_mask使用方法

def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None):
参数含义:
lengths: 整数张量,所有值必须小于maxlen
maxlen:标量整数张量,返回张量的最后维数的大小。默认值是lengths中的最大值。
dtype:输出张良类型,
name:操作的名称

import tensorflow as tf

a=tf.sequence_mask([1,2,4,6],maxlen=7)
b=tf.sequence_mask([[1,2],[4,6]])

with tf.Session() as sess:
    print('a:',sess.run(a))
    print('b:',sess.run(b))

输出:
tf.sequence_mask使用方法_第1张图片
a的输出为4行7列的数据,4是输入lengths列表的长度,7就是maxlen的参数,这四行的前1,2,4,6个数据为True,其他为False

b的shape为(2,2,6),(2,2)是lengths的形状,6是lengths中的最大值,由于maxlen没有给定,所以默认为输入lengths中的最大值6。

你可能感兴趣的:(学习,tensorflow)