随记·手撕coding | MultiheadAttention

无聊撕一下多头注意力吧~:qkv过完QKV线性层,按头切割,过attention,按头拼接,过fc融合即可输出。

import torch
import torch.nn as nn

class MultiheadAttention(nn.Module):
    # n_heads:多头注意力的数量
    # hid_dim:每个词输出的向量维度
    def __init__(self, hid_dim, n_heads

你可能感兴趣的:(算法岗面试,attention,transformer,自然语言处理,大模型,人工智能,深度学习)