无聊撕一下多头注意力吧~:qkv过完QKV线性层,按头切割,过attention,按头拼接,过fc融合即可输出。
import torch import torch.nn as nn class MultiheadAttention(nn.Module): # n_heads:多头注意力的数量 # hid_dim:每个词输出的向量维度 def __init__(self, hid_dim, n_heads