Transformer_复现_多头注意力机制

import os
import torch
import torch.nn as nn
from torch.utils.data import  Dataset,DataLoader
from tqdm import tqdm

def read_data(file_path,num=None):
    with open(file_path,"r",encoding="utf-8") as f:
        all_data = f.read().split("\n")

    all_text = []
    all_label = []

    for data in all_data:
        s_data = data.split("\t")
        if len(s_data) != 2:
            continue
        text,label = s_data
        all_text.append(text)
        all_label.append(int(label))
    if num:
        return all_text[:num],all_label[:num]
    else:
        return all_text, all_label

def get_word_2_index(all_data):
    word_2_index = {"[PAD]":0,"[UNK]":1}

    for data in all_data:
        for w in data:
            word_2_index[w] = word_2_index.get(w,len(word_2_index))
    return word_2_index

class MyDataset(Dataset):
    def __init__(self,all_text,all_label):
        self.all_text = all_text
        self.all_label = all_label

    def __len__(self):
        assert len(self.all_text) == len(self.all_label)
        return len(self.all_label)

    def __getitem__(self, index):

        text = self.all_text[index][:512]
        label = self.all_label[index]

        text_index = [word_2_index.get(i,1) for i in text]

        return text_index,label,len(text_index)

def coll_fn(batch_data):

    batch_text_idx,batch_label,batch_len = zip(*batch_data)
    batch_max_len = max(batch_len)

    batch_text_idx_new = []

    for text_idx,label,len_ in zip(batch_text_idx,batch_label,batch_len):
        text_idx = text_idx + [0] * (batch_max_len-len_)
        batch_text_idx_new.append(text_idx)

    return torch.tensor(batch_text_idx_new),torch.tensor(batch_label),torch.tensor(batch_len)

class Positional(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, x):
        return self.embedding(x)

# befor
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(emb_num, emb_num)
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.gelu(self.linear(x))


class Multi_Head_attention(nn.Module):
    def __init__(self):
        super().__init__()
        self.Q = nn.Linear(embedding_dim,embedding_dim)
        self.K = nn.Linear(embedding_dim,embedding_dim)
        self.V = nn.Linear(embedding_dim,embedding_dim)

    def forward(self, x):
        b,s,n = x.shape
        # x_ = x.reshape(b,s,head_num,-1)
        Q = self.Q(x).reshape(b,s,head_num,-1).transpose(1,2)
        K = self.K(x).reshape(b,s,head_num,-1).transpose(1,2)
        V = self.V(x).reshape(b,s,head_num,-1).transpose(1,2)

        score = torch.softmax(Q @ K.transpose(-1,-2) / 10,dim=-2)
        out = score @ V

        # out1 = out.reshape(b,s,n)
        out = out.transpose(1,2).reshape(b,s,n)
        return out

class Norm(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        return self.norm(x)

class Feed_Forward(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(embedding_dim,feed_num)
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(feed_num,embedding_dim)
        self.norm = nn.LayerNorm(embedding_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        linear1_out = self.linear1(x)
        gelu_out = self.gelu(linear1_out)
        linear2_out = self.linear2(gelu_out)
        norm_out = self.norm(linear2_out)
        return self.dropout(norm_out)

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.multi_head_attention = Multi_Head_attention()
        self.norm1 = Norm()
        self.feed_forward = Feed_Forward()
        self.norm2 = Norm()

    def forward(self,x):

        att_x = self.multi_head_attention.forward(x)
        norm1 = self.norm1.forward(att_x)

        adn_out1 = x + norm1

        ff_out = self.feed_forward.forward(adn_out1)
        norm2 = self.norm2.forward(ff_out)

        adn_out2 = adn_out1 + norm2

        return adn_out2

class MyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size,embedding_dim)
        self.positional_layer = Positional()

        self.encoder = nn.Sequential(*[Block() for _ in range(num_hidden_layers)])

        self.cls = nn.Linear(embedding_dim,num_class)
        self.loss_fun = nn.CrossEntropyLoss()

    def forward(self,inputs,labels=None):

        input_emb= self.embedding(inputs)
        position_emb = self.positional_layer(inputs)

        input_embeddings = (input_emb + position_emb)

        encoder_out = self.encoder.forward(input_embeddings)

        # sentence_f = torch.mean(encoder_out,dim=1)
        sentence_f,_ = torch.max(encoder_out,dim=1)

        pre = self.cls(sentence_f)

        if labels is not None:
            loss = self.loss_fun(pre,labels)
            return loss
        return torch.argmax(pre,dim=-1)


if __name__ == "__main__":
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    train_text,train_label = read_data(os.path.join("..","data","文本分类","train.txt"))
    test_text,test_label = read_data(os.path.join("..","data","文本分类","test.txt"))
    word_2_index = get_word_2_index(train_text)

    vocab_size = len(word_2_index)

    index_2_word = list(word_2_index)

    batch_size = 200
    epoch = 10
    embedding_dim = 256
    feed_num = 256
    num_hidden_layers = 2
    head_num = 2
    lr = 0.0001
    device = "cuda" if torch.cuda.is_available() else "cpu"
    num_class = len(set(train_label))

    train_dataset = MyDataset(train_text,train_label)
    train_dataloader = DataLoader(train_dataset,batch_size=batch_size,shuffle=False,collate_fn=coll_fn)

    test_dataset = MyDataset(test_text, test_label)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=coll_fn)

    model = MyTransformer().to(device)
    opt = torch.optim.Adam(model.parameters(),lr=lr)

    for e in range(epoch):
        for batch_text_idx,batch_label,batch_len in tqdm(train_dataloader):
            batch_text_idx = batch_text_idx.to(device)
            batch_label = batch_label.to(device)
            loss = model.forward(batch_text_idx,batch_label)

            loss.backward()
            opt.step()
            opt.zero_grad()
        # print(f"loss:{loss:.3f}")

        right_num = 0
        for batch_text_idx,batch_label,batch_len in tqdm(test_dataloader):
            batch_text_idx = batch_text_idx.to(device)
            batch_label = batch_label.to(device)
            pre = model.forward(batch_text_idx)

            right_num += int(torch.sum(batch_label == pre ))

        acc = right_num / len(test_dataset)
        print(f"acc:{acc:.3f}")

你可能感兴趣的:(Transformer_复现_多头注意力机制)