learn掩码张量

目录

1、什么是掩码张量

2、掩码张量的作用

3、代码演示

(1)、定义一个上三角矩阵,k=0或者 k默认为 0

(2)、k=1

(3)、k=-1

4、掩码张量代码实现

(1)、输出效果

(2)、输出效果分析


1、什么是掩码张量

  • 掩就是代表遮掩,码就是张量中的数值,它的尺寸不定,里面只有 1 和 0 的元素,代表的位置被遮掩或者不被遮掩,至于是 0 位置被遮掩还是 1 位置被遮掩可以自己定义,因此它的作用就是让另外一个张量中的数值被遮掩,也可以说成是被替换,它的表现形式是一个张量

2、掩码张量的作用

  • 在transformers中,掩码张量的主要作用应用在 attention时,有一些生成的attention张量中的值计算有可能已知了未来信息而得到的,未来信息被看到是因为训练时会把整个输出结果都一次性进行 Embedding,但是理论上解码器的输出却不是一次就能产生最终结果的,而是一次次通过上次结果综合得出的。因此,未来的信息可能被提前利用,所以,我们会进行遮掩

3、代码演示

(1)、定义一个上三角矩阵,k=0或者 k默认为 0

attn_shape = (1,3,3) # 定义掩码张量的形状
sub_mask = np.triu(np.ones(attn_shape), k = 0).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形,其中 k=1 是将上三角矩阵的所有为 1 的元素向上移动一行
print(sub_mask)

[[[1 1 1]
  [0 1 1]
  [0 0 1]]]

(2)、k=1

attn_shape = (1,3,3) # 定义掩码张量的形状
sub_mask = np.triu(np.ones(attn_shape), k = 1).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形,其中 k=1 是将上三角矩阵的所有为 1 的元素向上移动一行
print(sub_mask)

[[[0 1 1]
  [0 0 1]
  [0 0 0]]]

(3)、k=-1

attn_shape = (1,3,3) # 定义掩码张量的形状
sub_mask = np.triu(np.ones(attn_shape), k = -1).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形,其中 k=1 是将上三角矩阵的所有为 1 的元素向上移动一行
print(sub_mask)

[[[1 1 1]
  [1 1 1]
  [0 1 1]]]

4、掩码张量代码实现

import numpy as np
import torch
def subsequent_mask(size):
    """
    :param size: 生成向后遮掩的掩码张量,参数 size 是掩码张量的最后两个维度大小,它的最后两个维度形成一个方阵
    :return:
    """
    attn_shape = (1,size,size) # 定义掩码张量的形状
    subsequent_mask = np.triu(np.ones(attn_shape),k = 1).astype('uint8') # 定义一个上三角矩阵,元素为1,再使用其中的数据类型变为无符号8位整形
    return torch.from_numpy(1 - subsequent_mask) # 先将numpy 类型转化为 tensor,再做三角的翻转,将位置为 0 的地方变为 1,将位置为 1 的方变为 0
size = 5
sm = subsequent_mask(size)
print("sm :",sm)
import matplotlib.pyplot as plt
plt.figure(figsize=(5,5))
plt.imshow(subsequent_mask(20)[0])

(1)、输出效果

(2)、输出效果分析

  • 通过观察可视化方阵,黄色是 1 的部分,这里代表被遮掩,紫色代表没有被遮掩的信息,横坐标代表目标词汇的位置,纵坐标代表可查看的位置
  • 我们看到,在 0 的位置我们以看望过去都是黄色的,都被遮掩了,1的位置一眼望过去还是黄色,说明第一次词还没有产生,从第二个位置看过去,就能看到位置 1 的词,其他位置看不到,以此类推

你可能感兴趣的:(Tranformers,人工智能)