pytorch实现word embedding :torch.nn.Embedding

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
 
word_to_ix = {
     'hello': 0, 'world': 1}	#每个单词用一个数字表示
embeds = nn.Embedding(2, 5)	#定义embedding,2为单词的个数,5位embedding的维度

#获取一个Variable,其值为hello的index
hello_idx = torch.LongTensor([word_to_ix['hello']])
hello_idx = Variable(hello_idx)

#获取Variable对应的embedding,并打印出来
hello_embed = embeds(hello_idx)
print(hello_embed)

你可能感兴趣的:(pytorch学习笔记,python)