How to add a mask to a network in PyTorch

In deep learning we frequently run into scenarios that require a mask, for example:

  • In NLP, sequences are padded to a common length, and the padded positions are masked out during attention so they do not take part in the computation;
  • Mask-based pre-training tasks such as MLM and MAE mask out selected tokens (or patches) and train the model to predict them;
  • In Swin Transformer, after the shift operation, masked attention is needed to keep regions that are not physically adjacent from interacting with each other;
  • When computing a loss, we may want to ignore samples that should not contribute to that loss (see the sketch after this list).
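
As a minimal sketch of the last point (the names logits, targets, and valid_mask below are illustrative, not from the original post), the per-sample losses can be kept with reduction='none' and averaged only over the samples whose mask value is 1:

import torch
import torch.nn as nn

# Hypothetical loss-masking example: ignore some samples in a cross-entropy loss.
loss_fn = nn.CrossEntropyLoss(reduction='none')      # keep one loss value per sample

logits = torch.randn(4, 10, requires_grad=True)      # (batch, num_classes)
targets = torch.tensor([1, 3, 0, 7])
valid_mask = torch.tensor([1., 1., 0., 1.])          # 1 = use this sample, 0 = ignore it

per_sample_loss = loss_fn(logits, targets)           # shape (batch,)
# Average only over the valid samples; clamp avoids division by zero.
loss = (per_sample_loss * valid_mask).sum() / valid_mask.sum().clamp(min=1)
loss.backward()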

Example

In the attention operation, before applying the softmax to the attention logits, the logits at masked positions are set to a very large negative number such as -10000. After the softmax, these positions receive near-zero weights, so their contribution is suppressed. The code is as follows:

import torch
import torch.nn as nn


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., with_qkv=True):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5  # scale logits by 1/sqrt(head_dim)
        self.with_qkv = with_qkv
        if self.with_qkv:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)
        self.attn_drop = nn.Dropout(attn_drop)

    def forward(self, x, attention_mask=None):
        B, N, C = x.shape
        if self.with_qkv:
            # q, k, v: each of shape (B, num_heads, N, head_dim)
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]
        else:
            qkv = x.reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
            q, k, v = qkv, qkv, qkv

        # attention logits: (B, num_heads, N, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        if attention_mask is not None:
            # attention_mask uses 1 for positions to keep and 0 for positions to mask,
            # and must broadcast against attn, e.g. shape (B, 1, 1, N).
            attention_mask = attention_mask.to(dtype=attn.dtype)
            attention_mask = (1.0 - attention_mask) * -10000.0
            attn = attn + attention_mask
        attn = attn.softmax(dim=-1)  # masked positions now get near-zero weight
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        if self.with_qkv:
            x = self.proj(x)
            x = self.proj_drop(x)
        return x
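
Continuing with the Attention class above, a padding mask can be built from the valid sequence lengths and reshaped to (B, 1, 1, N) so it broadcasts over heads and query positions (the shapes and values here are an illustrative sketch, not from the original post):

# Hypothetical usage: two padded sequences with 3 and 5 valid tokens.
B, N, dim = 2, 5, 64
x = torch.randn(B, N, dim)

lengths = torch.tensor([3, 5])                                        # valid tokens per sequence
key_padding = (torch.arange(N)[None, :] < lengths[:, None]).float()   # (B, N), 1 = real token, 0 = padding
attention_mask = key_padding[:, None, None, :]                        # (B, 1, 1, N), broadcasts over heads/queries

attn_layer = Attention(dim, num_heads=8)
out = attn_layer(x, attention_mask=attention_mask)                    # padded keys receive ~0 attention weight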
