这个函数目前我主要用于数据填充时候使用。
这个是官方定义,耐心看完解释再看后面的例子,你会一下就懂了。
# 函数定义
sequence_mask(
lengths,
maxlen=None,
dtype=tf.bool,
name=None
)
# 返回数据
return mask类型数据
mask
张量,默认其中内部元素类型是tf.bool
(布尔变量)tf.Session()
打印可以得到一个array
数据。注解:一般实际代码中选择数据类型为tf.float32
,这样True
会变成1.,同理False
变成0.,看不懂可以继续往下看
1.返回值mask
张量:默认mask
张量就是布尔格式的一种张量表达,只有True和 False 格式,也可以通过参数dtype
指定其他数据格式。
2.参数lengths
:顾名思义表示的是长度;可以是标量,也可以是列表 [ ] ,也可以是二维列表[ [ ],[ ] ,…],甚至是多维列表…。一般列表类型的用的比较多
3.参数maxlen
:当默认None
,默认从lengths
中获取最大的那个数字,决定返回mask
张量的长度;当为N时,返回的是N长度。
如果觉得晦涩,举例,看完就懂了:
import tensorflow as tf
lenght = 4
mask_data = tf.sequence_mask(lengths=lenght)
# 输出结果,输出结果是长度为4的array,前四个True
array([ True, True, True, True])
# 定义maxlen时
mask_data = tf.sequence_mask(lengths=lenght,maxlen=6)
# 输出结果,输出结果是长度为6的array,前四个True
array([ True, True, True, True, False, False])
# 定义dtype时
mask_data = tf.sequence_mask(lengths=lenght,maxlen=6,dtype=tf.float32)
# 输出结果,输出结果是长度为6的array,前四个1.0
array([1., 1., 1., 1., 0., 0.], dtype=float32)
batch_data
有10个数据,每个数据是一个句子,每个句子不可能是一样长的,肯定有短的需要填充0元素,那么lengths
就专门记录每个句子的长度的。# 比如这个lenght就是记录了第一个句子2个单词,第二个句子2个单词,第三个句子4个单词
lenght = [2,2,4]
mask_data = tf.sequence_mask(lengths=lenght)
# 长度为max(lenght)
array([[ True, True, False, False],
[ True, True, False, False],
[ True, True, True, True]])
# 定义maxlen时
mask_data = tf.sequence_mask(lengths=lenght,maxlen=6)
# 长度为maxlen
array([[ True, True, False, False, False, False],
[ True, True, False, False, False, False],
[ True, True, True, True, False, False]])
# 定义dtype时
mask_data = tf.sequence_mask(lengths=lenght,maxlen=6,dtype=tf.float32)
# 长度为maxlen,数据格式为float32
array([[1., 1., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0.],
[1., 1., 1., 1., 0., 0.]], dtype=float32)
lenght = [[2,2,4],[3,4,5]]
mask_data = tf.sequence_mask(lengths=lenght)
# 输出
array([[[ True, True, False, False, False],
[ True, True, False, False, False],
[ True, True, True, True, False]],
[[ True, True, True, False, False],
[ True, True, True, True, False],
[ True, True, True, True, True]]])
这个填充很多时候lenghts
不是我们举例子这样使用的固定长度length
,大多数时候都是使用了tf.data.Dataset
得到的数据。
应用场景主要是在填充计算时候使用,比如你把没有单词的位置填充了0,如果纳入了前向传播计算,影响了最终经验损失函数的结果。那么我们如果通过tf.sequence_mask
得到的mask
张量,与损失函数结果进行对照相乘,可以去掉无用的损失值,保证了计算的准确性