pytorch 提供了一个简便方法torch.nn.Embedding.from_pretrained
,可以将文本与预训练的embedding对应起来:
词 | embedding |
---|---|
word1 | 0,2,3,4 |
word2 | 1,2,3,4 |
word3 | 2,2,3,4 |
… | … |
使用方法就是:
首先有一个预训练的embedding
列表:
torch.Tensor([
[0, 2, 3, 4],
[1, 2, 3, 4],
[2, 2, 3, 4],
[3, 2, 3, 4], ])
这个顺序与词表的顺序要一致,这样,如果输入一个1
,就意味着我要拿到第1个字的embedding,就是[1,2,3,4]
;
# coding: UTF-8
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self, embedding_pretrained):
super(Model, self).__init__()
self.embedding = nn.Embedding.from_pretrained(embedding_pretrained, freeze=False)
def forward(self, x):
out = self.embedding(x)
return out
if __name__ == '__main__':
# 预训练的 embedding
pre_train = torch.Tensor([[0, 2, 3, 4],
[1, 2, 3, 4],
[2, 2, 3, 4],
[3, 2, 3, 4], ])
model = Model(pre_train)
embedding = model(torch.Tensor([[1, 1, 1, 2, 0], # 第1句话包含的字的编号
[1, 0, 2, 1, 1]]).long()) # 第2句话包含的字的编号