Loading pretrained GloVe embeddings into a model

(Figure 1: illustration from the original post)

import torch
import torch.nn as nn
import torch.nn.functional as F

class Vocab:
    """Token vocabulary with a reserved "<unk>" entry for out-of-vocabulary lookups."""

    def __init__(self, tokens=None):
        """Build the vocabulary.

        Args:
            tokens: optional iterable of unique token strings. The special
                "<unk>" token is appended automatically if absent, so every
                vocabulary (even an empty one) can resolve unknown tokens.
        """
        self.idx_to_token = list()
        self.token_to_idx = dict()

        # Fix 1: the original only set ``self.unk`` when ``tokens`` was
        # given, so ``Vocab()[x]`` raised AttributeError.
        # Fix 2: the original used "" as the unknown token — almost
        # certainly "<unk>" eaten by HTML stripping in the source post.
        tokens = list(tokens) if tokens is not None else []
        if "<unk>" not in tokens:
            tokens.append("<unk>")
        for token in tokens:
            self.idx_to_token.append(token)
            self.token_to_idx[token] = len(self.idx_to_token) - 1
        self.unk = self.token_to_idx["<unk>"]

    def __len__(self):
        """Number of tokens, including the "<unk>" entry."""
        return len(self.idx_to_token)

    def __getitem__(self, token):
        """Return the id of ``token``, falling back to the unk id."""
        return self.token_to_idx.get(token, self.unk)

    def convert_tokens_to_ids(self, tokens):
        """Map tokens to ids; raises KeyError on unknown tokens (no unk fallback)."""
        return [self.token_to_idx[token] for token in tokens]

    # NOTE(review): "covert" is a typo for "convert", kept so existing
    # callers do not break.
    def covert_ids_to_token(self, indices):
        """Map ids back to their token strings."""
        return [self.idx_to_token[index] for index in indices]


def load_pretrained(load_path):
    """Load GloVe-style word vectors from a plain-text file.

    Each line is expected to hold a token followed by its space-separated
    float components (the standard ``glove.6B.*.txt`` layout).

    Args:
        load_path: path to the GloVe ``.txt`` file (UTF-8).

    Returns:
        (vocab, embeds): a ``Vocab`` built from the file's tokens (plus the
        automatically appended unknown token) and a float tensor of shape
        (num_tokens, dim) holding the vectors in file order.  Note the
        tensor has one row fewer than ``len(vocab)`` because the unknown
        token carries no pretrained vector.
    """
    tokens = []
    embeds = []
    with open(load_path, 'r', encoding='utf-8') as fin:
        # Stream the file line by line instead of readlines(): the GloVe
        # files are large, and readlines() would hold every line in memory
        # on top of the parsed floats.
        for line in fin:
            parts = line.rstrip().split(' ')
            if len(parts) < 2:
                continue  # skip blank or malformed lines defensively
            tokens.append(parts[0])
            embeds.append([float(x) for x in parts[1:]])
    vocab = Vocab(tokens)
    embeds = torch.tensor(embeds, dtype=torch.float)

    return vocab, embeds


# Load the 50-dimensional GloVe vectors; the file must sit next to this
# script (download from https://nlp.stanford.edu/projects/glove/).
pt_vocab, pt_embeds = load_pretrained('./glove.6B.50d.txt')

class MLP(nn.Module):
    """Two-layer perceptron whose embedding table is initialised from
    pretrained (GloVe) vectors.

    Args:
        pt_vocab: vocabulary whose ``idx_to_token`` order matches the rows
            of ``pt_embeddings``; ``pt_vocab.unk`` marks the one row left
            randomly initialised.
        pt_embeddings: float tensor of shape (num_pretrained_tokens,
            embedding_dim).
        hidden_dim: width of the hidden layer.
        num_class: number of output classes.
    """

    def __init__(self, pt_vocab, pt_embeddings, hidden_dim, num_class):
        super(MLP, self).__init__()
        embedding_dim = pt_embeddings.shape[1]
        vocab_size = len(pt_vocab)
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # Rows without a pretrained vector (just the unk row here) keep
        # this small random initialisation.
        self.embeddings.weight.data.uniform_(-0.1, 0.1)
        # Copy the pretrained vectors in under no_grad.  The original
        # indexed ``weight[i].data.copy_(...)`` (works, but fights
        # autograd) and recomputed ``pt_vocab[token]`` even though the
        # vocabulary was built from these very tokens, so it always
        # equals ``idx``.
        with torch.no_grad():
            for idx, _token in enumerate(pt_vocab.idx_to_token):
                if idx != pt_vocab.unk:
                    self.embeddings.weight[idx] = pt_embeddings[idx]

        # Fix: the first linear layer consumes the embeddings, so its
        # input size must be embedding_dim.  The original used
        # ``nn.Linear(hidden_dim, hidden_dim)``, which crashes in any
        # forward pass whenever hidden_dim != embedding_dim (128 vs 50
        # with glove.6B.50d).
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_class)
        self.activate = F.relu

    def forward(self):
        # Forward pass left unimplemented in the source material.
        pass

# Build the classifier on top of the loaded GloVe table (50-d embeddings,
# 128-unit hidden layer, binary output).
model = MLP(pt_vocab=pt_vocab, pt_embeddings=pt_embeds, hidden_dim=128, num_class=2)

Related topics: natural language processing, deep learning, Python, artificial intelligence