MultiHeadAttension源码解析——batch_first参数含义

文章目录

  • 前言
  • 1. 问题描述
  • 2. 相关概念

前言

简单介绍batch_first参数的含义和相关概念。

1. 问题描述

Pytorch的多头注意力(MultiHeadAttension)代码中,有一个batch_first参数,在传递参数的时候必须注意。

def forward(self, query: Tensor, key: Tensor, value: Tensor, 
                key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, 
                attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
      
    if self.batch_first:
        query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

官方文档对batch_first的解释如下。

batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).

简单翻译一下:如果设置为True,输入和输出张量按照(batch,seq,feature)的顺序提供。默认值是False。按照(seq,batch,feature)的顺序。

通过查看源码发现,如果按照(batch,seq,feature)的顺序传入参数,并且batch_first设置为True,那么会自动转换成(seq,batch,feature)的顺序,输出结果的时候,再转换回来。

if self.batch_first:
    query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

转换代码

# 参数1和0意味着交换第一和第二个索引,也就是batch和seq。
transpose(1, 0)

2. 相关概念

batch、seq和feature

什么是batch:批量大小,就是一次传入的序列(句子)的数量。

什么是seq:序列长度,即单词数量。

什么是feature:特征长度,每个单词向量(Embedding)的长度。

也记为N(批量)、T(序列)、C(特征)。

为什么默认不是批量在前呢?

根据这篇文章https://zhuanlan.zhihu.com/p/32103001的解释:

为了便于并行计算,cuDNN中的RNN模型提供的API就是batch_size在第二维度。

虽然上文是按照RNN来解释的,但是应该对注意力模型也适用。

至于cuDNN这样排序的原因,是因为batch first=True意味着模型的输入(一个Tensor)在内存中存储时,先存储第一个sequence,再存储第二个,而如果是seq放在前面,模型的输入在内存中,先存储所有序列的第一个单元,然后是第二个单元。

你可能感兴趣的:(深度学习,pytorch,深度学习)