import torch
import torch.nn as nn
import torch.optim as optim
import random
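# Toy sequence-to-sequence demo: a small nn.Transformer memorizes three tiny
# Chinese chat pairs, then answers a prompt by sampling from the decoder one
# token at a time.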
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embedding_size, nhead, num_layers, dim_feedforward, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        # nn.Transformer defaults to batch_first=False: inputs are (seq_len, batch, d_model).
        self.transformer = nn.Transformer(d_model=embedding_size, nhead=nhead,
                                          num_encoder_layers=num_layers, num_decoder_layers=num_layers,
                                          dim_feedforward=dim_feedforward, dropout=dropout)
        self.fc = nn.Linear(embedding_size, vocab_size)

    def forward(self, src, tgt, tgt_mask=None):
        src_emb = self.embedding(src)
        tgt_emb = self.embedding(tgt)
        # The causal tgt_mask keeps each decoder position from attending to later positions.
        output = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask)
        return self.fc(output)
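# Note: TransformerModel adds no positional information, so token order is
# invisible to the attention layers; the three training pairs here are small
# enough to fit anyway. For anything larger, a standard sinusoidal encoding
# could be added and applied to src_emb/tgt_emb in forward() -- a minimal
# sketch (unused below):
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=100):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)                          # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2)
                             * (-torch.log(torch.tensor(10000.0)) / d_model))  # (d_model/2,)
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)                                         # saved with the model, not trained

    def forward(self, x):  # x: (seq_len, batch, d_model)
        return x + self.pe[:x.size(0)]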
# One optimization step with teacher forcing: the decoder is fed the gold target
# shifted right (tgt[:-1]) and trained to predict the target shifted left (tgt[1:]),
# with a causal mask so position t cannot peek at later target tokens.
def train(model, optimizer, criterion, src, tgt):
    model.train()
    optimizer.zero_grad()
    tgt_in = tgt[:-1].unsqueeze(1)                      # decoder input, shape (T, 1)
    tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_in.size(0))
    output = model(src.unsqueeze(1), tgt_in, tgt_mask)  # (T, 1, vocab)
    loss = criterion(output.reshape(-1, output.size(-1)), tgt[1:].reshape(-1))
    loss.backward()
    optimizer.step()
    return loss.item()
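# Autoregressive decoding: start the decoder from token 1 (the same symbol the
# training targets begin with), sample the next token from the temperature-scaled
# softmax at each step, and stop at token 1 ([EOS]) or after max_len tokens.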
def generate(seq, model, max_len, temperature=1.0):
    model.eval()  # disable dropout during sampling
    with torch.no_grad():
        src = torch.tensor(seq, dtype=torch.long).unsqueeze(1)    # (S, 1)
        tgt = torch.tensor([1], dtype=torch.long)                 # decoder prefix: token 1, as in training
        while len(tgt) < max_len:
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0))
            output = model(src, tgt.unsqueeze(1), tgt_mask)       # (T, 1, vocab)
            logits = output[-1, 0] / temperature                  # logits for the next token
            prob = nn.functional.softmax(logits, dim=-1)
            next_token = torch.multinomial(prob, 1)
            tgt = torch.cat((tgt, next_token))
            if next_token.item() == 1:                            # token 1 doubles as [EOS]
                break
    return seq + tgt[1:].tolist()                                 # prompt tokens followed by the sampled reply
# Token ids: 0=[PAD], 1=[EOS] (also used as the decoder start token),
# 2=我 (I), 3=喜欢 (like), 4=吃 (eat), 5=pizza, 6=你 (you),
# 7=喜欢 (like; duplicate of 3), 8=什么 (what), 9=食物 (food)
vocab = ['[PAD]', '[EOS]', '我', '喜欢', '吃', 'pizza', '你', '喜欢', '什么', '食物']
# Three tiny parallel pairs; every src ends with token 1, every tgt starts and ends with it.
train_data = [
    {'src': [2, 4, 6, 1], 'tgt': [1, 7, 8, 9, 1]},
    {'src': [2, 3, 4, 5, 1], 'tgt': [1, 6, 1]},
    {'src': [2, 6, 4, 1], 'tgt': [1, 7, 8, 9, 1]},
]
model = TransformerModel(vocab_size=len(vocab), embedding_size=32, nhead=4,
                         num_layers=2, dim_feedforward=64, dropout=0.1)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # skip [PAD] labels (unused here, since nothing is padded)
for epoch in range(1000):
    random.shuffle(train_data)
    total_loss = 0
    for data in train_data:
        src = torch.tensor(data['src'], dtype=torch.long)
        tgt = torch.tensor(data['tgt'], dtype=torch.long)
        total_loss += train(model, optimizer, criterion, src, tgt)
    print('Epoch {}, Loss {:.4f}'.format(epoch + 1, total_loss / len(train_data)))
# Each call returns the prompt tokens plus a sampled reply; sampling via
# torch.multinomial is stochastic, so the tail varies run to run, but a
# well-fit model should end each reply with token 1 ([EOS]).
a = generate([2, 6, 4], model, 10)
b = generate([2, 4, 6], model, 10)
# Print each training pair, then the two sampled replies, one sequence per line.
for sample in train_data:
    print(''.join(vocab[i] for i in sample['src']))
    print(''.join(vocab[i] for i in sample['tgt']))
print(''.join(vocab[i] for i in a))
print(''.join(vocab[i] for i in b))