讲解下swin transformer attention mask生成的核心

 本文启发来源于多篇博客。文末附有一些链接

h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1
 
mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)
 
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

上面的几行的目的是生成一个类似于这样的矩阵

         [[0., 0., 0., 0., 1., 1., 2., 2.],
         [0., 0., 0., 0., 1., 1., 2., 2.],
         [0., 0., 0., 0., 1., 1., 2., 2.],
         [0., 0., 0., 0., 1., 1., 2., 2.],
         [3., 3., 3., 3., 4., 4., 5., 5.],
         [3., 3., 3., 3., 4., 4., 5., 5.],
         [6., 6., 6., 6., 7., 7., 8., 8.],
         [6., 6., 6., 6., 7., 7., 8., 8.]]

后来在partition了下变成类似这样
        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 2., 1., 1., 2., 2.],
         [3., 3., 3., 3., 3., 3., 3., 3., 6., 6., 6., 6., 6., 6., 6., 6.],
         [4., 4., 5., 5., 4., 4., 5., 5., 7., 7., 8., 8., 7., 7., 8., 8.]]

维度是4 ✖49

我想讲的核心是这个

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)

这里分别在第一维度第二维度扩充

4*49*1-4*1*49维,会自动广播成4*49*49。

我重点就说这里,因为前面计算attention Q✖V后

我们也会得到一个49*49的矩阵。他是从7*7窗口展平后的两个向量进行乘法计算的。

而这里的广播分别是行和列的广播,即行复制和列复制。

我拿几个元素举例,需要从矩阵乘法去思考。

第一行第一列的元素代表相乘时的第一个元素和向量2的第一个元素,二者相减会为0也就意味着他们来自同一个区,可以计算attention

第一行第二列的元素代表相乘时的第一个元素和向量2的第二个元素,二者相减会为0也就意味着他们来自同一个区,可以计算attention

而这样的相减如果有一个元素不为0,从他是第几行第几列就能看出他来源于第一个向量的第几个元素和第二个向量的第几个元素,这个位置相乘的结果不能做attention

Swin-Transformer网络结构详解

【深度学习】详解 Swin Transformer (SwinT)

史上最详细的Swin-Transformer 掩码机制(mask of window attentation)————shaoshuai

你可能感兴趣的:(transformer,深度学习,人工智能)