Transformer Code Supplement

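This post supplements two CSDN blog posts, "Transformer - Attention is all you need paper reading" and "[Hung-yi Lee Machine Learning] Transformer supplementary notes", and records my understanding of the related code.

A digression first: in Hung-yi Lee's course, multi-head attention was described as multiplying the resulting q, k, v by different matrices to obtain more q, k, v.

In practice, the approach used here is direct slicing: with two heads, for example, q^i is split into two parts, q^{i,1} and q^{i,2}. This is explained in BERT Intro (CSDN blog); I also recommend the Bilibili video 手推transformer (deriving the Transformer by hand).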
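To make the "direct slicing" concrete, here is a minimal sketch. The sizes (d_model = 8, two heads) are assumptions chosen only for illustration: the projected query of one token is cut into contiguous chunks, one per head.

```python
import torch

torch.manual_seed(0)
d_model, num_heads = 8, 2            # illustrative sizes (assumed, not from the lecture)
d_head = d_model // num_heads

q_i = torch.randn(d_model)           # query vector q^i of a single token
heads = q_i.view(num_heads, d_head)  # row 0 is q^{i,1}, row 1 is q^{i,2}

q_i1, q_i2 = heads[0], heads[1]
print(q_i1.shape, q_i2.shape)        # each head sees d_head = 4 dimensions
# Concatenating the slices gives back the original vector
print(torch.equal(torch.cat([q_i1, q_i2]), q_i))
```

Nothing is recomputed per head here: each head simply reads its own slice of the same projection.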

self-attention

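This section implements the self-attention module. It starts from a single sentence and, setting aside for now how everything is batched together, takes the module apart to see how each component works.

What needs to be solved now:

  1. How is the input embedded?
  2. How is position information preserved?
  3. How are the three matrices initialized?

Attention for a single sentence

Input embedding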

import torch

sentence = 'Life is short, eat dessert first'

dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print("dictionary: {}".format(dc))

sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print("sentence, but words have been replaced by index in dictionary: \n{}".format(sentence_int))

torch.manual_seed(123)
# len(sentence.replace(',', '').split()) == 6
# embedded length == 16
embed = torch.nn.Embedding(6, 16)
embedded_sentence = embed(sentence_int).detach()

print("sentence, but word embedded: \n{}".format(embedded_sentence))
dictionary: {'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}
sentence, but words have been replaced by index in dictionary: 
tensor([0, 4, 5, 2, 1, 3])
sentence, but word embedded: 
tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,
          0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],
        [ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,
         -1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],
        [-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,
          0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,
          0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],
        [ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,
          2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])

Embedding — PyTorch 2.1 documentation
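As a small check on what the embedding layer does, the sketch below (reusing the same seed, vocabulary size and index tensor as above) shows that calling the layer is the same as indexing rows of its weight matrix. It is only meant to illustrate the lookup-table behaviour.

```python
import torch

torch.manual_seed(123)
embed = torch.nn.Embedding(6, 16)        # 6 words in the vocabulary, 16 dims each
sentence_int = torch.tensor([0, 4, 5, 2, 1, 3])

looked_up = embed(sentence_int).detach()
indexed = embed.weight[sentence_int].detach()   # plain row indexing into the table

print(looked_up.shape)
print(torch.equal(looked_up, indexed))
```

So the word order in the sentence only decides which rows are picked; the table itself has one row per dictionary entry.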

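Positional embedding

I found that there does not seem to be one fixed name for this: some call it position embedding, some position encoding, and there are also positional embedding and positional encoding. Every combination exists (orz).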

import numpy as np

### position embedding
def sinusoid_positional_encoding(length, dimensions):
    # even dimension index 2i:  sin(position / 10000^{2i/d_model})
    # odd dimension index 2i+1: cos(position / 10000^{2i/d_model})
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2*(i//2)/dimensions) for i in range(dimensions)]
    
    PE = np.array([get_position_angle_vec(i) for i in range(length)])
    PE[:, 0::2] = np.sin(PE[:, 0::2])
    PE[:, 1::2] = np.cos(PE[:, 1::2])
    return PE
# The printouts in the rest of this section were generated with np.sin applied
# to the odd columns as well, which is why each pair of columns there is
# identical and row 0 is all zeros; with np.cos the odd columns of row 0 are 1.
embedded_position = torch.tensor(sinusoid_positional_encoding(6, 16))
print("position embedding: \n{}".format(embedded_position))
position embedding: 
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00],
        [ 8.4147e-01,  8.4147e-01,  3.1098e-01,  3.1098e-01,  9.9833e-02,
          9.9833e-02,  3.1618e-02,  3.1618e-02,  9.9998e-03,  9.9998e-03,
          3.1623e-03,  3.1623e-03,  1.0000e-03,  1.0000e-03,  3.1623e-04,
          3.1623e-04],
        [ 9.0930e-01,  9.0930e-01,  5.9113e-01,  5.9113e-01,  1.9867e-01,
          1.9867e-01,  6.3203e-02,  6.3203e-02,  1.9999e-02,  1.9999e-02,
          6.3245e-03,  6.3245e-03,  2.0000e-03,  2.0000e-03,  6.3246e-04,
          6.3246e-04],
        [ 1.4112e-01,  1.4112e-01,  8.1265e-01,  8.1265e-01,  2.9552e-01,
          2.9552e-01,  9.4726e-02,  9.4726e-02,  2.9996e-02,  2.9996e-02,
          9.4867e-03,  9.4867e-03,  3.0000e-03,  3.0000e-03,  9.4868e-04,
          9.4868e-04],
        [-7.5680e-01, -7.5680e-01,  9.5358e-01,  9.5358e-01,  3.8942e-01,
          3.8942e-01,  1.2615e-01,  1.2615e-01,  3.9989e-02,  3.9989e-02,
          1.2649e-02,  1.2649e-02,  4.0000e-03,  4.0000e-03,  1.2649e-03,
          1.2649e-03],
        [-9.5892e-01, -9.5892e-01,  9.9995e-01,  9.9995e-01,  4.7943e-01,
          4.7943e-01,  1.5746e-01,  1.5746e-01,  4.9979e-02,  4.9979e-02,
          1.5811e-02,  1.5811e-02,  5.0000e-03,  5.0000e-03,  1.5811e-03,
          1.5811e-03]], dtype=torch.float64)

The Bilibili video 手推transformer mentions that this method is related to the Fourier transform (a detail I have not seen elsewhere, so I am noting it here).
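One way to see the Fourier flavour is to write the sinusoid formula from the paper in plain Python: every pair of dimensions is a (sin, cos) pair at one frequency, and the frequencies shrink geometrically across pairs. This is a sketch only; sinusoid_row is a hypothetical helper name, and the sizes (6 positions, 16 dimensions) are taken from the example sentence.

```python
import math

def sinusoid_row(position, dimensions):
    """One row of the encoding: sin on even indices, cos on odd indices."""
    row = []
    for i in range(dimensions):
        angle = position / (10000 ** (2 * (i // 2) / dimensions))
        row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
    return row

pe = [sinusoid_row(pos, 16) for pos in range(6)]

# Position 0: every sin is 0 and every cos is 1
print(pe[0][:4])
# Each (sin, cos) pair lies on the unit circle, one circle per frequency
print(all(abs(pe[3][i] ** 2 + pe[3][i + 1] ** 2 - 1) < 1e-12
          for i in range(0, 16, 2)))
```

Moving from one position to the next rotates each pair by a fixed angle, fast for the first pairs and very slowly for the last ones.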

Initializing the weight matrices

new_embedding = (embedded_position+embedded_sentence).to(torch.float32)
print('add embedding together:\n{}'.format(new_embedding))
torch.manual_seed(123)

d = new_embedding.shape[1]
print('embedding dimension:\n{}'.format(d))

d_q, d_k, d_v = 24, 24, 28

# torch.rand samples uniformly from [0, 1); torch.nn.Parameter wraps the tensor
# so that it is treated as a trainable parameter (requires_grad=True)
W_query = torch.nn.Parameter(torch.rand(d_q, d))
W_key = torch.nn.Parameter(torch.rand(d_k, d))
W_value = torch.nn.Parameter(torch.rand(d_v, d))
print('size of query matrix: {}'.format(W_query.shape))
print('size of key matrix: {}'.format(W_key.shape))
print('size of value matrix: {}'.format(W_value.shape))
add embedding together:
tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,
          0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],
        [ 1.3561,  1.8352,  0.0523, -0.7716,  0.0555,  1.7234, -2.2913,  1.1194,
          0.6816,  0.7033, -0.9456, -0.0733, -0.1516,  0.1177,  0.4406, -1.4462],
        [ 1.1646,  0.3597,  1.5954,  1.4184, -0.1961,  0.6879, -0.1536, -1.6840,
         -1.5825, -1.0564,  0.9095, -0.7155, -0.5931, -0.7092,  0.6236, -1.3722],
        [-1.1838,  0.3195, -1.3211,  1.8650, -0.0930, -0.6388, -0.4044, -0.9919,
          0.9105,  1.5842,  0.6361, -0.1660,  0.1013, -0.0905,  0.2672, -0.5841],
        [-0.8338, -1.7773,  0.7846,  1.8713,  1.9704,  1.6905,  1.4015, -0.0748,
          0.5365, -1.5323,  0.9792, -1.1355, -1.1549,  0.3295, -0.6302, -2.8387],
        [-0.0821,  0.6632, -0.4780,  2.1331, -0.7409,  1.7933,  1.2108,  0.2963,
          2.2973, -0.7537, -0.2650,  0.7855, -0.6546, -0.7929,  0.1854,  0.2309]],
       dtype=torch.float64)
embedding dimension:
16
size of query matrix: torch.Size([24, 16])
size of key matrix: torch.Size([24, 16])
size of value matrix: torch.Size([28, 16])

Parameter — PyTorch 2.1 documentation

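OK, let's pause here and take stock:

The sequence at this point is new_embedding (word embedding + position embedding).

word embedding: 6×16 (6 tokens, each represented by a 16-dimensional vector)

position embedding: 6×16 (same size as the word embedding, because the two are added)

W_query: 24×16

W_key: 24×16

W_value: 28×16

So when a query is computed later, each token (1×16) is multiplied with W_query (24×16); one of the two has to be transposed.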
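The "one of the two has to be transposed" remark can be checked directly. In this sketch the token and sentence tensors are random stand-ins (assumed, not the embeddings from above); only the shapes match the post.

```python
import torch

torch.manual_seed(123)
d, d_q = 16, 24
W_query = torch.rand(d_q, d)     # 24 x 16, same layout as in the post
x = torch.randn(d)               # one token embedding (stand-in), 16 values
X = torch.randn(6, d)            # six tokens stacked as rows, 6 x 16

q_a = W_query.matmul(x)          # (24 x 16) @ (16,)  -> (24,)
q_b = x.matmul(W_query.T)        # (16,) @ (16 x 24)  -> (24,)
print(torch.allclose(q_a, q_b, atol=1e-4))

Q_a = W_query.matmul(X.T).T      # transpose the input, then transpose back
Q_b = X.matmul(W_query.T)        # or transpose the weight once
print(Q_a.shape, torch.allclose(Q_a, Q_b, atol=1e-4))
```

Both orders give the same queries; which one to use is purely a layout convention.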

Computing q, k, v

# Use new_embedding (word + position), the same input the batched version below uses
x_1 = new_embedding[0]
query_1 = W_query.matmul(x_1)
key_1 = W_key.matmul(x_1)
value_1 = W_value.matmul(x_1)

x_2 = new_embedding[1]
query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)

torch.matmul — PyTorch 2.1 documentation

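# The numbers printed later in this section were generated with W_key used on
# this line by mistake (so querys equalled keys); the shapes are unaffected.
querys = W_query.matmul(new_embedding.T).T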
keys = W_key.matmul(new_embedding.T).T
values = W_value.matmul(new_embedding.T).T

print("querys.shape:", querys.shape)
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
querys.shape: torch.Size([6, 24])
keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])

Computing the attention score

alpha_24 = query_2.dot(keys[4])
print(alpha_24)

For example, this is the (unnormalized) attention score of the 2nd query against the 5th key.

import torch.nn.functional as F

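# Entry (i, j) is the score of query i against key j; softmax over the last
# dim so that each row (one query) sums to 1 and can weight the values directly.
# The values printed below came from keys.matmul(querys.T) with dim=0, where
# each column rather than each row sums to 1, so they differ from this form.
attention_score = F.softmax(querys.matmul(keys.T) / d_k**0.5, dim=-1)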
print(attention_score)
tensor([[3.8184e-01, 3.7217e-08, 2.8697e-08, 2.3739e-03, 8.8205e-04, 2.0233e-18],
        [5.1460e-03, 3.1125e-02, 3.3185e-09, 1.9323e-03, 1.3870e-07, 4.5397e-13],
        [1.8988e-04, 1.5880e-10, 9.9998e-01, 6.8005e-04, 8.3661e-03, 2.5704e-26],
        [2.1968e-04, 1.2932e-09, 9.5111e-09, 7.5759e-01, 6.8821e-05, 7.0347e-20],
        [1.5536e-02, 1.7667e-11, 2.2270e-05, 1.3099e-02, 9.9068e-01, 3.1426e-23],
        [5.9707e-01, 9.6887e-01, 1.1463e-12, 2.2433e-01, 5.2653e-07, 1.0000e+00]],
       grad_fn=)

Getting the context vector

context_vector_2 = attention_score[2].matmul(values)
print(context_vector_2)
tensor([-2.8135, -0.2665, -0.1881,  0.4058,  0.8079, -3.1120,  0.5449, -1.2232,
        -0.1618,  0.3803,  0.6926, -0.4669,  0.2446, -0.3647, -0.0034, -2.2524,
        -2.7228, -1.5109, -0.7725, -1.0958, -2.1254,  0.3064,  0.5129, -0.1340,
         0.7020, -2.2086, -1.9595,  0.4520], grad_fn=)
context_vector = attention_score.matmul(values)
print(context_vector)
tensor([[ 2.8488e-01,  6.4077e-01,  1.0665e+00,  5.5947e-01, -3.2868e-01,
          4.2391e-01, -3.2123e-01,  1.0594e-01,  6.5982e-01,  6.1927e-01,
          8.2067e-01,  4.3722e-01,  6.4925e-01,  5.9935e-01,  6.7425e-01,
          3.6706e-01,  5.0318e-01,  9.9682e-02,  1.1377e-01,  1.2804e-01,
          9.1880e-01,  7.6178e-01, -4.2619e-01,  2.5550e-01, -8.1348e-02,
          3.1145e-01,  1.9705e-01,  3.8195e-01],
        [ 3.6250e-02,  3.7593e-02,  8.9476e-02,  9.9750e-02,  9.1430e-02,
          6.2556e-02,  5.8136e-02,  5.5746e-02,  3.5098e-02,  4.1406e-02,
          4.1621e-02,  1.9771e-02,  4.0799e-02, -4.7170e-03,  4.1176e-02,
          4.3792e-02,  6.2029e-02,  5.2132e-02,  7.6929e-03,  5.4507e-02,
          1.4537e-02,  6.9540e-02,  4.1809e-02,  5.8921e-02,  1.2542e-02,
          1.4625e-01,  3.0627e-02,  1.0624e-01],
        [-2.8135e+00, -2.6652e-01, -1.8809e-01,  4.0583e-01,  8.0793e-01,
         -3.1120e+00,  5.4491e-01, -1.2232e+00, -1.6184e-01,  3.8030e-01,
          6.9257e-01, -4.6693e-01,  2.4462e-01, -3.6468e-01, -3.3741e-03,
         -2.2524e+00, -2.7228e+00, -1.5109e+00, -7.7255e-01, -1.0958e+00,
         -2.1254e+00,  3.0638e-01,  5.1293e-01, -1.3400e-01,  7.0203e-01,
         -2.2086e+00, -1.9595e+00,  4.5198e-01],
        [ 1.3995e+00, -5.1583e-02, -7.6128e-01,  6.2276e-01,  1.4197e+00,
         -1.1195e+00,  2.6502e-01,  9.7265e-02, -1.3257e+00,  5.2765e-01,
         -9.0406e-01,  1.0977e+00,  1.0775e+00, -1.1202e+00, -5.3005e-01,
          1.1657e+00,  5.2906e-01, -3.4296e-01, -1.0341e+00, -9.9314e-02,
          2.4160e-01,  1.0506e+00, -2.5196e-01, -1.2585e+00,  7.7441e-01,
         -3.8052e-02,  1.4004e+00,  4.0364e-01],
        [-1.9422e+00, -1.1669e-01,  2.4155e+00, -6.0575e-01,  1.1378e-01,
         -8.1691e-01,  2.8678e-01, -2.6922e+00,  1.9804e+00,  2.7446e+00,
          1.9828e-01, -1.5773e+00, -5.2589e-01,  2.2252e+00, -2.9130e-01,
         -4.2694e+00,  2.4834e+00, -3.3346e+00, -2.5167e-01, -2.8141e+00,
          1.3780e+00, -1.5563e-01, -1.4588e+00,  5.3617e-01, -5.3745e-01,
         -7.6528e-01,  1.2408e+00,  3.5827e+00],
        [ 5.3134e+00,  3.5967e+00,  7.1373e+00,  5.9613e+00,  6.1520e+00,
          5.0065e+00,  4.2107e+00,  5.2589e+00,  9.2143e-01,  6.5614e+00,
          2.7412e+00,  4.6712e+00,  4.9725e+00,  2.2118e+00,  5.2451e+00,
          4.4219e+00,  4.5800e+00,  2.9179e+00,  2.2116e+00,  5.3678e+00,
          5.7133e+00,  7.1016e+00,  3.7317e+00,  5.1325e+00,  4.1306e+00,
          9.4941e+00,  5.6733e+00,  9.7489e+00]], grad_fn=)
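The steps of this section (project, score, softmax, weighted sum) can be collected into one function. This is a minimal sketch under stated assumptions: single_head_attention is a hypothetical helper name, the input is a random stand-in for new_embedding, and the score matrix is laid out so that row i belongs to query i.

```python
import torch
import torch.nn.functional as F

def single_head_attention(x, W_query, W_key, W_value):
    """x: (seq_len, d); each W: (d_out, d). Returns (context, weights)."""
    queries = x.matmul(W_query.T)                  # (seq_len, d_q)
    keys = x.matmul(W_key.T)                       # (seq_len, d_k)
    values = x.matmul(W_value.T)                   # (seq_len, d_v)
    d_k = keys.shape[-1]
    scores = queries.matmul(keys.T) / d_k ** 0.5   # [i, j]: query i against key j
    weights = F.softmax(scores, dim=-1)            # normalise over the keys
    return weights.matmul(values), weights

torch.manual_seed(123)
x = torch.randn(6, 16)                             # stand-in for new_embedding
W_q, W_k, W_v = torch.rand(24, 16), torch.rand(24, 16), torch.rand(28, 16)
context, weights = single_head_attention(x, W_q, W_k, W_v)
print(context.shape)         # one 28-dimensional context vector per token
print(weights.sum(dim=-1))   # every row sums to 1
```

Because each row of weights is a probability distribution, every context vector is a weighted average of the value vectors.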

References

Positional Encoding: Everything You Need to Know - inovex GmbH
Build your own Transformer from scratch using Pytorch | by Arjun Sarkar | Towards Data Science
Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch (sebastianraschka.com)
2021-03-18-Transformers - Multihead Self Attention Explanation & Implementation in Pytorch.ipynb - Colaboratory (google.com)
通俗易懂的理解傅里叶变换(一)[收藏] - 知乎 (zhihu.com)
Linear Relationships in the Transformer’s Positional Encoding - Timo Denk's Blog
Transformer 中的 positional embedding - 知乎 (zhihu.com)
transformer中使用的position embedding为什么是加法? - 知乎 (zhihu.com)

multi-head self-attention

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # d_k is the size of each head's key and query; it is also used later for scaling
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
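        # Compute the attention scores; either Q or K has to be transposed, depending on the convention
        # With this definition, entry (i, j) of attn_scores means:
        # the attention of the i-th query on the j-th key (how relevant they are)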
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output

With that, I more or less understand it (orz).
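The mask argument of scaled_dot_product_attention is the one piece not exercised above. Here is a small sketch of what it does, using a causal mask as an assumed example (the class itself accepts any mask where 0 means "hide"); the tensor sizes are illustrative.

```python
import math
import torch

torch.manual_seed(0)
seq_len, d_k = 4, 8
Q = torch.randn(1, 1, seq_len, d_k)   # (batch, heads, seq_len, d_k)
K = torch.randn(1, 1, seq_len, d_k)

# Causal mask: position i may only look at positions <= i (1 = keep, 0 = hide)
mask = torch.tril(torch.ones(seq_len, seq_len))

attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
attn_scores = attn_scores.masked_fill(mask == 0, -1e9)   # hidden entries get a huge negative score
attn_probs = torch.softmax(attn_scores, dim=-1)

print(attn_probs[0, 0])   # entries above the diagonal are (numerically) zero
```

Filling with -1e9 before the softmax, rather than zeroing afterwards, keeps each row a valid distribution over the visible positions.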

view

Changing dimensions

randmat = torch.rand((3, 2, 5))
print("view(2, 3, 5): \n{}".format(randmat.view(2,3,5)))
print("view(2, 5, 3): \n{}".format(randmat.view(2,5,3)))
view(2, 3, 5): 
tensor([[[0.8058, 0.3869, 0.7523, 0.1501, 0.1501],
         [0.3409, 0.5355, 0.3474, 0.8371, 0.6785],
         [0.6564, 0.8204, 0.0539, 0.7422, 0.2216]],

        [[0.9450, 0.7839, 0.7118, 0.8868, 0.4249],
         [0.1633, 0.5220, 0.7583, 0.7841, 0.0838],
         [0.4304, 0.5082, 0.3141, 0.1689, 0.0869]]])
view(2, 5, 3): 
tensor([[[0.8058, 0.3869, 0.7523],
         [0.1501, 0.1501, 0.3409],
         [0.5355, 0.3474, 0.8371],
         [0.6785, 0.6564, 0.8204],
         [0.0539, 0.7422, 0.2216]],

        [[0.9450, 0.7839, 0.7118],
         [0.8868, 0.4249, 0.1633],
         [0.5220, 0.7583, 0.7841],
         [0.0838, 0.4304, 0.5082],
         [0.3141, 0.1689, 0.0869]]])
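What the printouts above suggest is that view never moves data; it only regroups the same flat sequence. The sketch below uses torch.arange instead of random numbers (an assumption made so the ordering is easy to read) and contrasts view with transpose.

```python
import torch

x = torch.arange(30).view(3, 2, 5)

# view only regroups the flat sequence 0..29
v = x.view(2, 5, 3)
print(torch.equal(v.flatten(), x.flatten()))   # True: same order in memory

# transpose swaps axes, so a different element lands at each index
t = x.transpose(1, 2)                          # shape (3, 5, 2)
r = x.view(3, 5, 2)                            # same shape, different content
print(t.shape == r.shape, torch.equal(t, r))   # True False
```

This is why split_heads uses view followed by transpose: view cuts d_model into heads, and transpose then moves the head axis in front of the sequence axis.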

transpose

randmat = torch.rand((3, 2, 5))
print(randmat)
print("transpose(-2,-1): \n{}".format(randmat.transpose(-2,-1)))
print("transpose(1,2): \n{}".format(randmat.transpose(1,2)))
tensor([[[0.3440, 0.9779, 0.9154, 0.6843, 0.9358],
         [0.5081, 0.7446, 0.0274, 0.6329, 0.6427]],

        [[0.6770, 0.6826, 0.2888, 0.8483, 0.9896],
         [0.1457, 0.3154, 0.6381, 0.6555, 0.2204]],

        [[0.4549, 0.0385, 0.1135, 0.8426, 0.8534],
         [0.7915, 0.4030, 0.8209, 0.3390, 0.6290]]])
transpose(-2,-1): 
tensor([[[0.3440, 0.5081],
         [0.9779, 0.7446],
         [0.9154, 0.0274],
         [0.6843, 0.6329],
         [0.9358, 0.6427]],

        [[0.6770, 0.1457],
         [0.6826, 0.3154],
         [0.2888, 0.6381],
         [0.8483, 0.6555],
         [0.9896, 0.2204]],

        [[0.4549, 0.7915],
         [0.0385, 0.4030],
         [0.1135, 0.8209],
         [0.8426, 0.3390],
         [0.8534, 0.6290]]])
transpose(1,2): 
tensor([[[0.3440, 0.5081],
         [0.9779, 0.7446],
         [0.9154, 0.0274],
         [0.6843, 0.6329],
         [0.9358, 0.6427]],

        [[0.6770, 0.1457],
         [0.6826, 0.3154],
         [0.2888, 0.6381],
         [0.8483, 0.6555],
         [0.9896, 0.2204]],

        [[0.4549, 0.7915],
         [0.0385, 0.4030],
         [0.1135, 0.8209],
         [0.8426, 0.3390],
         [0.8534, 0.6290]]])
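The two printouts are identical because, for a 3-D tensor, dims -2 and -1 are dims 1 and 2. A related detail worth a sketch is why combine_heads calls .contiguous() before view: transpose returns a tensor with rearranged strides, and view cannot always regroup that layout. The shapes below are illustrative.

```python
import torch

x = torch.rand(3, 2, 5)

# For a 3-D tensor, dims -2 and -1 are the same as dims 1 and 2
print(torch.equal(x.transpose(-2, -1), x.transpose(1, 2)))

t = x.transpose(1, 2)               # shape (3, 5, 2), same storage, new strides
print(t.is_contiguous())            # False

try:
    t.view(3, 10)                   # view needs a compatible memory layout
except RuntimeError as err:
    print("view failed:", type(err).__name__)

flat = t.contiguous().view(3, 10)   # copy into a fresh layout first
print(flat.shape)
```

Tensor.reshape would also work here, since it copies only when needed; the class above spells out the two steps explicitly.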

References

Build your own Transformer from scratch using Pytorch | by Arjun Sarkar | Towards Data Science
Python numpy.transpose 详解 - 我的明天不是梦 - 博客园 (cnblogs.com)

Appendix

"""Self-attention module

1. Read the code and explain the following:
    - The nature of the dataset
    - The data flow
    - The shapes of the tensors
    - Why can the attention module be used for this dataset?
2. Create a training loop and evaluate the model according to the instructions
"""
import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm.auto import tqdm


class SampleDataset(Dataset):
    def __init__(
        self,
        size: int = 1024,
        emb_dim: int = 32,
        sequence_length: int = 8,
        n_classes: int = 3,
    ):
        self.embeddings = torch.randn(size, emb_dim)
        self.sequence_length = sequence_length
        self.n_classes = n_classes

    def __len__(self):
        return len(self.embeddings) - self.sequence_length + 1

    def __getitem__(self, idx):
        indices = np.random.choice(
            np.arange(0, len(self.embeddings)), self.sequence_length
        )
        # np.random.shuffle(indices)
        return (
            self.embeddings[indices],  # sequence_length x emb_dim
            torch.tensor(np.max(indices) % self.n_classes),
        )


def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    # The length of the key and value sequences need to be the same
    d_k = query.size(-1)
    # scores: ... x query_length x key_length
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super().__init__()
        assert d_model % heads == 0
        # We assume d_v always equals d_k
        self.d = d_model // heads # d_model: 32 heads: 4
        self.h = heads # h: 4
        self.linears = clones(nn.Linear(d_model, d_model), 4) # 4 identical layers (input: 32, output: 32)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            l(x).view(nbatches, -1, self.h, self.d).transpose(1, 2)
            for l, x in zip(self.linears, (query, key, value))
        ]  # each of the three: nbatches x h x sequence_length x d

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d)
        return self.linears[-1](x)


class SequenceClassifier(nn.Module):
    def __init__(self, heads: int = 4, d_model: int = 32, n_classes: int = 3):
        super().__init__()
        self.attention = MultiHeadAttention(heads, d_model)
        self.linear = nn.Linear(d_model, n_classes)

    def forward(self, x):
        # x: N x sequence_length x emb_dim
        x = self.attention(x, x, x)
        x = self.linear(x[:, 0])
        return x


def main(
    n_epochs: int = 1000,
    size: int = 256,
    emb_dim: int = 128,
    sequence_length: int = 8,
    n_classes: int = 3,
):
    dataset = SampleDataset(
        size=size, emb_dim=emb_dim, sequence_length=sequence_length, n_classes=n_classes
    )
    # TODO: create a training loop

    # TODO: Evaluate with the same dataset

    # TODO: Evaluate with a different sequence length (12)


if __name__ == "__main__":
    main()
