import torch
import numpy as np
import matplotlib.pyplot as plt
def subsequent_mask(size):
"Mask out subsequent positions."
attn_shape = (1, size, size)
print(attn_shape)
print(np.ones(attn_shape))
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
print(subsequent_mask)
return torch.from_numpy(subsequent_mask) == 0
print(subsequent_mask(5))
plt.figure(figsize=(5,5))
plt.imshow(subsequent_mask(20)[0])
print(subsequent_mask(5))
plt.figure(figsize=(5,5))
plt.imshow(subsequent_mask(20)[0])
import numpy as np
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[10,11,12],[10,11,12],[10,11,12]], k=-1)))
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[10,11,12],[10,11,12],[10,11,12]], k=0)))
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[10,11,12],[10,11,12],[10,11,12]], k=1)))
#输出
数组的上三角部分:
[[ 1 2 3]
[ 4 5 6]
[ 0 8 9]
[ 0 0 12]
[ 0 0 0]
[ 0 0 0]
[ 0 0 0]]
数组的上三角部分:
[[1 2 3]
[0 5 6]
[0 0 9]
[0 0 0]
[0 0 0]
[0 0 0]
[0 0 0]]
数组的上三角部分:
[[0 2 3]
[0 0 6]
[0 0 0]
[0 0 0]
[0 0 0]
[0 0 0]
[0 0 0]]
矩阵的shape是(7,3),可见k=-1是从第三行(index=2)为下标开始的,依次类推k=0是从第二行(index=1)为下标开始的,k=1是从第一行(index=0)为下标开始的
import numpy as np
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9]], k=-1)))
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9]], k=0)))
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6],[7,8,9]], k=1)))
#输出
数组的上三角部分:
[[1 2 3]
[4 5 6]
[0 8 9]]
数组的上三角部分:
[[1 2 3]
[0 5 6]
[0 0 9]]
数组的上三角部分:
[[0 2 3]
[0 0 6]
[0 0 0]]
import numpy as np
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6]], k=-1)))
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6]], k=0)))
print('数组的上三角部分:\n{}'.format(np.triu([[1,2,3],[4,5,6]], k=1)))
#输出
数组的上三角部分:
[[1 2 3]
[4 5 6]]
数组的上三角部分:
[[1 2 3]
[0 5 6]]
数组的上三角部分:
[[0 2 3]
[0 0 6]]
从第一行可以看到对于这个2*3的矩阵,k=-1表示从第三行开始,但是矩阵没有第三行,所以原样输出
其他k的取值还是按照之前陈述的规律输出
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
作用就是numpy.ndarray类型中数字转换成uint8类型的数据
uint8表示:uint8是8位无符号整型
是将ndarray类型的数据转换成tensor类型的数据
将每个位置的数==0和零判断是否相等,如果=0,此位置为True,否为False
目的:是将下三角为0,上三角为1的矩阵进行翻转得到,下三角为True,上三角为False