import torch
# Vocabulary: position in this list is the character's integer index.
idx2char = ['e', 'h', 'l', 'o']
# The input sequence is 'hello' as character indices, shape (batch, seq_len).
# Unlike charater_testRNN and BasicRNN, the characters are not one-hot encoded
# here; the Embedding layer maps each index to a dense vector instead.
x_data = [[1, 0, 2, 2, 3]]
# The target sequence is 'ohlol', flattened to shape (batch * seq_len,).
y_data = [3, 1, 2, 3, 2]
# nn.Embedding needs integer (long) indices as input, and CrossEntropyLoss
# needs long class indices as targets.
inputs = torch.tensor(x_data, dtype=torch.long)
labels = torch.tensor(y_data, dtype=torch.long)
# Hyperparameters. The vocabulary size is both the number of embedding rows
# and the number of output classes, so derive it instead of hard-coding 4.
num_class = len(idx2char)
input_size = len(idx2char)
hidden_size = 8
embedding_size = 10
num_layers = 2
batch_size = len(x_data)
seq_len = len(x_data[0])
class Model(torch.nn.Module):
    """Character-level RNN: embedding -> multi-layer RNN -> linear classifier.

    Each input index is embedded and run through a ``batch_first`` RNN. Every
    time step's output is then projected to class logits.

    All constructor arguments are optional. Any argument left as ``None``
    falls back to the module-level hyperparameter with the same role
    (``input_size``, ``embedding_size``, ``hidden_size``, ``num_layers``,
    ``num_class``), so ``Model()`` builds the same network as before.
    """

    def __init__(self, vocab_size=None, embedding_dim=None, hidden_dim=None,
                 n_layers=None, n_classes=None):
        super().__init__()
        # Fall back to the script-level hyperparameters for backward compatibility.
        if vocab_size is None:
            vocab_size = input_size
        if embedding_dim is None:
            embedding_dim = embedding_size
        if hidden_dim is None:
            hidden_dim = hidden_size
        if n_layers is None:
            n_layers = num_layers
        if n_classes is None:
            n_classes = num_class
        self.emb = torch.nn.Embedding(vocab_size, embedding_dim)
        self.rnn = torch.nn.RNN(input_size=embedding_dim, hidden_size=hidden_dim,
                                num_layers=n_layers, batch_first=True)
        self.fc = torch.nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        """Return the class logits for every time step.

        Args:
            x: Long tensor of character indices, shape (batch, seq_len).

        Returns:
            Tensor of shape (batch * seq_len, n_classes). Rows are ordered
            batch-major: all steps of sample 0, then all steps of sample 1, ...
        """
        x = self.emb(x)  # (batch, seq_len, embedding_dim)
        # Zero initial hidden state. It is created with the embedding output's
        # dtype and device, so the model still works after .to(device) or
        # .double().
        hidden = torch.zeros(self.rnn.num_layers, x.size(0), self.rnn.hidden_size,
                             dtype=x.dtype, device=x.device)
        x, _ = self.rnn(x, hidden)  # (batch, seq_len, hidden_dim)
        x = self.fc(x)              # (batch, seq_len, n_classes)
        # Flatten to the 2-D (N, C) layout that CrossEntropyLoss expects.
        return x.reshape(-1, self.fc.out_features)
net = Model()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)
num_epochs = 15  # single source for both the loop bound and the progress message
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = net(inputs)  # (batch * seq_len, num_class) logits
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # Greedy decode: take the highest-scoring class at each time step.
    # `outputs` was computed before this epoch's optimizer step.
    idx = outputs.argmax(dim=1)
    predicted = ''.join(idx2char[i] for i in idx.tolist())
    print('Predicted: %s, Epoch [%d / %d] loss = %.3f'
          % (predicted, epoch + 1, num_epochs, loss.item()))
# Sample output (运行结果). Weights are randomly initialized and no seed is set,
# so exact predictions and losses vary between runs:
# Predicted: ellel, Epoch [1 / 15] loss = 1.532
# Predicted: lllll, Epoch [2 / 15] loss = 1.202
# Predicted: ohlll, Epoch [3 / 15] loss = 0.972
# Predicted: ohlll, Epoch [4 / 15] loss = 0.763
# Predicted: ohlol, Epoch [5 / 15] loss = 0.593
# Predicted: ohlol, Epoch [6 / 15] loss = 0.439
# Predicted: ohlol, Epoch [7 / 15] loss = 0.312
# Predicted: ohlol, Epoch [8 / 15] loss = 0.217
# Predicted: ohlol, Epoch [9 / 15] loss = 0.151
# Predicted: ohlol, Epoch [10 / 15] loss = 0.105
# Predicted: ohlol, Epoch [11 / 15] loss = 0.074
# Predicted: ohlol, Epoch [12 / 15] loss = 0.053
# Predicted: ohlol, Epoch [13 / 15] loss = 0.039
# Predicted: ohlol, Epoch [14 / 15] loss = 0.030
# Predicted: ohlol, Epoch [15 / 15] loss = 0.024