Multi-Head Attention: Principle and Code Implementation

1. Multi-Head Attention Structure

[Figure 1: Multi-Head Attention structure]

2. Multi-Head Attention Computation

Multi-Head Attention is built by combining multiple Self-Attention modules. For the same text, a single attention head yields one representation space; with several attention heads, the model obtains several different representation spaces. This is the idea behind Multi-Head Attention. In other words, Multi-Head Attention gives attention multiple "representation subspaces": each head uses its own Query / Key / Value weight matrices, which are randomly initialized, and through training the word embeddings are projected into different representation subspaces.
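For reference, the computation of a single head can be written as follows (a standard formulation from the Transformer paper "Attention Is All You Need"; here $X$ denotes the input embeddings and $W_i^Q$, $W_i^K$, $W_i^V$ are the $i$-th head's projection matrices):

$$
\mathrm{head}_i = \mathrm{Attention}(X W_i^Q,\; X W_i^K,\; X W_i^V)
= \mathrm{softmax}\!\left(\frac{(X W_i^Q)(X W_i^K)^{\top}}{\sqrt{d_k}}\right) X W_i^V
$$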

[Figure 2: Multi-Head Attention computation]

Splitting the model into multiple heads creates multiple subspaces and lets the model attend to different aspects of the information. In the figure above, Multi-Head Attention runs the Scaled Dot-Product Attention process 8 times in parallel and then concatenates the outputs, as written out below.
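In formula form (again following the original Transformer paper), the 8 head outputs are concatenated and passed through an output projection $W^O$:

$$
\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_8)\, W^O
$$

Note that the simplified implementation below returns the concatenated heads directly and omits the final $W^O$ projection.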

3. Multi-Head Attention Code Implementation

from math import sqrt

import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    dim_in: int  # input dimension
    dim_k: int   # key and query dimension
    dim_v: int   # value dimension
    num_heads: int  # number of heads, for each head, dim_* = dim_* // num_heads

    def __init__(self, dim_in, dim_k, dim_v, num_heads=8):
        super(MultiHeadSelfAttention, self).__init__()
        # dim_k and dim_v must be divisible by num_heads
        assert dim_k % num_heads == 0 and dim_v % num_heads == 0, "dim_k and dim_v must be divisible by num_heads"
        self.dim_in = dim_in
        self.dim_k = dim_k
        self.dim_v = dim_v
        self.num_heads = num_heads
        # define the linear projection matrices for Q, K, and V
        self.linear_q = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_k = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_v = nn.Linear(dim_in, dim_v, bias=False)
        self._norm_fact = 1 / sqrt(dim_k // num_heads)  # scaling factor 1/sqrt(d_k), with d_k the per-head key dimension

    def forward(self, x):
        # x: tensor of shape (batch, n, dim_in)
        batch, n, dim_in = x.shape
        assert dim_in == self.dim_in

        nh = self.num_heads
        dk = self.dim_k // nh  # dim_k of each head
        dv = self.dim_v // nh  # dim_v of each head

        q = self.linear_q(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        k = self.linear_k(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        v = self.linear_v(x).reshape(batch, n, nh, dv).transpose(1, 2)  # (batch, nh, n, dv)

        dist = torch.matmul(q, k.transpose(2, 3)) * self._norm_fact  # batch, nh, n, n
        dist = torch.softmax(dist, dim=-1)  # batch, nh, n, n

        att = torch.matmul(dist, v)  # batch, nh, n, dv
        att = att.transpose(1, 2).reshape(batch, n, self.dim_v)  # batch, n, dim_v
        return att
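A minimal usage sketch (the dimensions below are arbitrary example values, not from the original post): a batch of 2 sequences, each with 10 tokens of embedding size 512, attended over 8 heads.

x = torch.rand(2, 10, 512)  # (batch, n, dim_in)
multi_head_att = MultiHeadSelfAttention(dim_in=512, dim_k=512, dim_v=512, num_heads=8)
out = multi_head_att(x)
print(out.shape)  # torch.Size([2, 10, 512]), i.e. (batch, n, dim_v)

With these settings each head works with dim_k // num_heads = 64 query/key dimensions and dim_v // num_heads = 64 value dimensions, matching the per-head split performed in forward().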
