Implementing the Word2Vec Skip-gram Model in PyTorch

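The code first builds a custom dataset class, Word2VecDataset, that generates the training pairs. It then defines the Skip-gram model and trains it with a cross-entropy loss and the Adam optimizer.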
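To make the pair generation concrete, here is a minimal sketch of the same windowing logic applied to a single sentence. The helper name skipgram_pairs is hypothetical and is not part of the listing; it works on the tokens themselves rather than on indices, purely for readability.

```python
# Hypothetical helper (not in the original listing): enumerate the
# (center, context) pairs for one tokenized sentence.
def skipgram_pairs(tokens, window_size=2):
    pairs = []
    for i, center in enumerate(tokens):
        # Context positions lie within window_size of i, clipped to the sentence.
        start = max(0, i - window_size)
        end = min(len(tokens), i + window_size + 1)
        for j in range(start, end):
            if j != i:  # a word is never its own context
                pairs.append((center, tokens[j]))
    return pairs

sentence = ['She', 'is', 'a', 'good', 'dancer']
pairs = skipgram_pairs(sentence)
for center, context in pairs:
    print(center, '->', context)
```

With a window size of 2, this five-word sentence yields 14 pairs; words at the edges of the sentence simply get fewer context words.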

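In each epoch, the code iterates over the data loader and, for every batch, runs the forward pass, computes the loss, backpropagates, and updates the weights. After training, the learned word vectors are read from the embedding layer, and word_vector holds the vector for a specific word.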
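The loss computed for each pair is the softmax cross-entropy between the model's scores over the vocabulary and the true context word. The sketch below works this out by hand in plain Python for one pair; the function name cross_entropy and the logit values are illustrative assumptions. nn.CrossEntropyLoss computes the same quantity and, by default, averages it over the batch.

```python
import math

def cross_entropy(logits, target_idx):
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    # Negative log-probability of the target class under softmax.
    return log_sum_exp - logits[target_idx]

vocab_size = 19  # number of distinct words in the example corpus

# A model with no preference scores every word equally.
uniform_logits = [0.0] * vocab_size
loss_uniform = cross_entropy(uniform_logits, target_idx=3)
print(f'{loss_uniform:.4f}')  # 2.9444, i.e. ln(19)

# A model that strongly favors the correct context word has a loss near 0.
confident_logits = [0.0] * vocab_size
confident_logits[3] = 10.0
loss_confident = cross_entropy(confident_logits, target_idx=3)
print(f'{loss_confident:.4f}')
```

The uniform case gives a useful reference point: a loss around ln(19) ≈ 2.94 means the model is doing no better than guessing among the 19 words.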

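Make sure PyTorch is installed before running the code; pip install torch installs it. The code runs on the GPU when one is available. The expression torch.device('cuda' if torch.cuda.is_available() else 'cpu') falls back to the CPU on its own, so nothing needs to be removed on a machine without a GPU.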

The following example code implements the Skip-gram model in PyTorch:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Hyperparameters
embedding_dim = 100
window_size = 2
learning_rate = 0.001
epochs = 100
batch_size = 32

# Example corpus
corpus = [['I', 'enjoy', 'playing', 'football', 'with', 'my', 'friends'],
          ['We', 'like', 'to', 'play', 'tennis', 'on', 'weekends'],
          ['She', 'is', 'a', 'good', 'dancer']]

# Create vocabulary
# Sort the vocabulary so word indices are stable across runs
vocab = sorted(set(word for sentence in corpus for word in sentence))
vocab_size = len(vocab)
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for idx, word in enumerate(vocab)}

# Generate training data
class Word2VecDataset(Dataset):
    def __init__(self, corpus, word2idx):
        self.data = []
        for sentence in corpus:
            word_indices = [word2idx[word] for word in sentence]
            for center_word_idx, center_word in enumerate(word_indices):
                for context_word_idx in range(max(0, center_word_idx - window_size), min(center_word_idx + window_size + 1, len(word_indices))):
                    if context_word_idx != center_word_idx:
                        context_word = word_indices[context_word_idx]
                        self.data.append((center_word, context_word))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

dataset = Word2VecDataset(corpus, word2idx)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Define Skip-gram model
class SkipGramModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(SkipGramModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)
        
    def forward(self, center_word):
        embedded = self.embedding(center_word)
        output = self.linear(embedded)
        return output

# Pick the GPU when available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The model must live on the same device as its input batches
model = SkipGramModel(vocab_size, embedding_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training
for epoch in range(epochs):
    running_loss = 0.0
    for i, (center_word, context_word) in enumerate(dataloader):
        optimizer.zero_grad()

        # Move the batch to the same device as the model
        center_word = center_word.to(device)
        context_word = context_word.to(device)
        
        output = model(center_word)
        loss = criterion(output, context_word)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    average_loss = running_loss / len(dataloader)
    print(f'Epoch {epoch+1}/{epochs}, Loss: {average_loss:.4f}')

# Get trained word embeddings
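# detach() drops the autograd graph; cpu() is required before numpy() when training ran on a GPU
trained_embeddings = model.embedding.weight.detach().cpu().numpy()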

# Example usage - Getting word vector for a word
word = 'football'
word_vector = trained_embeddings[word2idx[word]]
print(f"Word vector for '{word}': {word_vector}")

Running the code produces the following output:

Epoch 1/100, Loss: 3.1324
Epoch 2/100, Loss: 3.0791
Epoch 3/100, Loss: 2.9902
Epoch 4/100, Loss: 2.9392
Epoch 5/100, Loss: 2.8870
Epoch 6/100, Loss: 2.8166
Epoch 7/100, Loss: 2.7615
Epoch 8/100, Loss: 2.7017
Epoch 9/100, Loss: 2.6500
Epoch 10/100, Loss: 2.5993
Epoch 11/100, Loss: 2.5496
Epoch 12/100, Loss: 2.5013
Epoch 13/100, Loss: 2.4621
Epoch 14/100, Loss: 2.4079
Epoch 15/100, Loss: 2.3660
Epoch 16/100, Loss: 2.3229
Epoch 17/100, Loss: 2.2795
Epoch 18/100, Loss: 2.2398
Epoch 19/100, Loss: 2.1998
Epoch 20/100, Loss: 2.1582
Epoch 21/100, Loss: 2.1278
Epoch 22/100, Loss: 2.1023
Epoch 23/100, Loss: 2.0569
Epoch 24/100, Loss: 2.0245
Epoch 25/100, Loss: 1.9936
Epoch 26/100, Loss: 1.9639
Epoch 27/100, Loss: 1.9344
Epoch 28/100, Loss: 1.9137
Epoch 29/100, Loss: 1.8888
Epoch 30/100, Loss: 1.8586
Epoch 31/100, Loss: 1.8352
Epoch 32/100, Loss: 1.8200
Epoch 33/100, Loss: 1.7815
Epoch 34/100, Loss: 1.7685
Epoch 35/100, Loss: 1.7531
Epoch 36/100, Loss: 1.7209
Epoch 37/100, Loss: 1.7049
Epoch 38/100, Loss: 1.6881
Epoch 39/100, Loss: 1.6775
Epoch 40/100, Loss: 1.6517
Epoch 41/100, Loss: 1.6390
Epoch 42/100, Loss: 1.6238
Epoch 43/100, Loss: 1.6077
Epoch 44/100, Loss: 1.5939
Epoch 45/100, Loss: 1.5745
Epoch 46/100, Loss: 1.5703
Epoch 47/100, Loss: 1.5574
Epoch 48/100, Loss: 1.5458
Epoch 49/100, Loss: 1.5308
Epoch 50/100, Loss: 1.5215
Epoch 51/100, Loss: 1.5122
Epoch 52/100, Loss: 1.4988
Epoch 53/100, Loss: 1.4958
Epoch 54/100, Loss: 1.4773
Epoch 55/100, Loss: 1.4746
Epoch 56/100, Loss: 1.4618
Epoch 57/100, Loss: 1.4560
Epoch 58/100, Loss: 1.4506
Epoch 59/100, Loss: 1.4380
Epoch 60/100, Loss: 1.4266
Epoch 61/100, Loss: 1.4257
Epoch 62/100, Loss: 1.4148
Epoch 63/100, Loss: 1.4090
Epoch 64/100, Loss: 1.4070
Epoch 65/100, Loss: 1.3940
Epoch 66/100, Loss: 1.3890
Epoch 67/100, Loss: 1.3846
Epoch 68/100, Loss: 1.3813
Epoch 69/100, Loss: 1.3738
Epoch 70/100, Loss: 1.3717
Epoch 71/100, Loss: 1.3681
Epoch 72/100, Loss: 1.3594
Epoch 73/100, Loss: 1.3593
Epoch 74/100, Loss: 1.3504
Epoch 75/100, Loss: 1.3447
Epoch 76/100, Loss: 1.3439
Epoch 77/100, Loss: 1.3397
Epoch 78/100, Loss: 1.3315
Epoch 79/100, Loss: 1.3260
Epoch 80/100, Loss: 1.3253
Epoch 81/100, Loss: 1.3229
Epoch 82/100, Loss: 1.3215
Epoch 83/100, Loss: 1.3148
Epoch 84/100, Loss: 1.3160
Epoch 85/100, Loss: 1.3072
Epoch 86/100, Loss: 1.3105
Epoch 87/100, Loss: 1.3104
Epoch 88/100, Loss: 1.3018
Epoch 89/100, Loss: 1.2912
Epoch 90/100, Loss: 1.2950
Epoch 91/100, Loss: 1.2938
Epoch 92/100, Loss: 1.2951
Epoch 93/100, Loss: 1.2859
Epoch 94/100, Loss: 1.2902
Epoch 95/100, Loss: 1.2840
Epoch 96/100, Loss: 1.2748
Epoch 97/100, Loss: 1.2840
Epoch 98/100, Loss: 1.2763
Epoch 99/100, Loss: 1.2772
Epoch 100/100, Loss: 1.2746


Word vector for 'football':

[-1.2727762   0.8401019  -0.5115612   2.0667355   1.1854529  -0.7444803
 -1.9658612  -1.0488677   0.98938674 -1.1675086   1.582392    1.7414839
 -0.4892138  -1.2149098   0.15343344 -1.8318586   0.41794038  0.25481498
  0.6008032  -0.23904797  0.80143225 -1.0495795  -1.0174142  -0.01827855
  2.7477944  -0.9574399   1.025569    2.4843202  -0.2796719  -0.4390253
 -1.4423424  -1.8073392   0.1897556   0.90259725  2.7565296  -0.28331178
 -1.8443514   0.77545553 -1.0289538   0.71483964  1.1801128  -0.22635305
  0.5960759   0.6690206  -1.9100318   1.2388043  -0.68522704  0.92120373
  1.0252377  -1.4376261  -0.6595934   0.31699112  0.6751458   0.99656415
  0.40565705 -1.0904227  -0.3513346  -0.66078615  1.1834346  -1.0899751
 -1.4925232  -0.30818892  1.4249563   0.06006899 -3.2386255   0.96192694
 -1.1045157   0.5540482  -1.5388466  -0.8721646   1.1221852   1.6488599
  0.44869688  1.1519432  -1.4588032  -0.04230021 -0.33113605  1.1316347
 -0.7425484  -0.11400439  0.37237874 -0.34573358  0.4140474  -0.04413145
  0.6157635  -1.0094129  -1.2208599  -0.7154122   0.9412035   0.9452426
 -0.0973389  -0.23566085  0.34300375 -0.95858365  0.8764276  -0.5669889
 -1.933235    0.22371146  1.6641699   1.3258857 ]
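
A single vector is hard to interpret on its own; a common way to inspect embeddings is to rank other words by cosine similarity to a query word. The following is a minimal sketch in plain Python. The helper names cosine_similarity and most_similar and the toy 2-dimensional vectors are assumptions made for illustration; they are not part of the code above.

```python
import math

def cosine_similarity(a, b):
    # Dot product divided by the product of the vector lengths.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def most_similar(word, embeddings, word2idx, top_k=3):
    # Score every other word against the query and keep the top_k best.
    query = embeddings[word2idx[word]]
    scores = [(other, cosine_similarity(query, embeddings[idx]))
              for other, idx in word2idx.items() if other != word]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)[:top_k]

# Toy vectors, made up for illustration only (not trained values).
toy_word2idx = {'football': 0, 'tennis': 1, 'dancer': 2}
toy_embeddings = [[1.0, 0.0], [0.8, 0.6], [0.0, 1.0]]

neighbors = most_similar('football', toy_embeddings, toy_word2idx, top_k=2)
print(neighbors)
```

The same function should also accept the trained vectors, for example most_similar('football', trained_embeddings, word2idx). With a corpus of only three sentences, the resulting neighbors are unlikely to be meaningful; a much larger corpus is needed before similarity rankings reflect word usage.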
