DATAWHALE-动手学深度学习PyTorch 笔记记录2 attention mask

【Attention中mask pad的weight的做法】
在attention中,对attention score进行softmax时,需要考虑到query与pad计算得到的score应该忽略。我们在处理时可以先正常地用高维tensor形式将所有score计算出来,然后根据key的句长将pad所在位置的weight进行mask掉。
下面的代码实现了给定二维tensor X,根据X_len将X中指定位置替换为value值。

def SequenceMask(X, X_len,value=-1e6):
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float)[None, :] < X_len[:, None]
    X[~mask]=value
    return X

实际实现的时候,由Q和K直接计算出的weight Tensor的size为[batch_size, Q_seq_len, K_seq_len],valid_len的size为[batch_size]。在每个样本中需要mask掉的pad的位置都是相同的,也就是有效句长都是相同的。因此我们需要对weight和valid_len都做一些变换,使得能够直接应用sequence_mask来处理。

valid_len = valid_len.view(-1,1).repeat(1, X.size(1)).view(-1)
X = X.view(-1, X.size(-1)) 
#将X变为二维张量

DATAWHALE-动手学深度学习PyTorch 笔记记录2 attention mask_第1张图片

你可能感兴趣的:(Deep,Learning,python)