Recurrent Neural Networks in PyTorch

1. Implementation Process

Example: given the input string "hello", train the model to predict the output string "ohlol", one character class per time step.
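Each character is first encoded as a class index over the four-letter vocabulary ['e','h','l','o']. A minimal sketch of how the hard-coded x_data/y_data in the script below can be derived (char2idx is a helper introduced here for illustration, not part of the original script):

idx2char = ['e', 'h', 'l', 'o']                    # vocabulary
char2idx = {c: i for i, c in enumerate(idx2char)}  # illustrative inverse map
print([char2idx[c] for c in 'hello'])  # [1, 0, 2, 2, 3] -> x_data
print([char2idx[c] for c in 'ohlol'])  # [3, 1, 2, 3, 2] -> y_data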
The full code is as follows:

import torch
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # work around the duplicate OpenMP runtime error (common with matplotlib under Anaconda on Windows)

# parameters
num_class = 4        # size of the output vocabulary
num_layers = 2       # number of stacked RNN layers
input_size = 4       # size of the input vocabulary
hidden_size = 8      # dimension of the RNN hidden state
embedding_size = 10  # dimension of the embedding vectors
batch_size = 1
seq_len = 5

idx2char = ['e', 'h', 'l', 'o']
x_data = [[1, 0, 2, 2, 3]]  # 'hello' as indices, shape (batch, seq_len)
y_data = [3, 1, 2, 3, 2]    # 'ohlol' as indices, shape (batch * seq_len)

inputs = torch.LongTensor(x_data)
labels = torch.LongTensor(y_data)

# model: Embedding -> RNN -> Linear
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.emb = torch.nn.Embedding(input_size, embedding_size)  # (vocab size, embedding dim)
        self.rnn = torch.nn.RNN(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers,
                                batch_first=True)  # batch_first: inputs are (batch, seq_len, features)
        self.fc = torch.nn.Linear(hidden_size, num_class)

    def forward(self, x):
        # initial hidden state: (num_layers, batch, hidden_size)
        hidden = torch.zeros(num_layers, x.size(0), hidden_size)
        x = self.emb(x)             # (batch, seq_len, embedding_size)
        x, _ = self.rnn(x, hidden)  # (batch, seq_len, hidden_size)
        x = self.fc(x)              # (batch, seq_len, num_class)
        return x.view(-1, num_class)  # flatten to (batch * seq_len, num_class)

net = Model()
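
# Sanity check (an added sketch, not part of the original script): the forward
# pass maps the (batch, seq_len) index tensor to (batch * seq_len, num_class)
# logits, which matches the shape CrossEntropyLoss expects against `labels`.
assert net(inputs).shape == (batch_size * seq_len, num_class)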

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)

epoch_list = []
loss_list = []
for epoch in range(100):
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)

    epoch_list.append(epoch+1)
    loss_list.append(loss.item())

    loss.backward()
    optimizer.step()

    _, idx = outputs.max(dim=1)  # most likely class per time step
    idx = idx.numpy()            # class indices carry no gradient
    print('Predicted string:', ''.join([idx2char[x] for x in idx]), end='')
    print(', Epoch [%d/100] loss=%.3f' % (epoch + 1, loss.item()))

# plot the training loss curve
plt.plot(epoch_list, loss_list, 'b')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid()
plt.show()

The output is:

Predicted string: eoell, Epoch [1/100] loss=1.381
Predicted string: ollll, Epoch [2/100] loss=1.119
Predicted string: ohlll, Epoch [3/100] loss=0.921
Predicted string: ohlll, Epoch [4/100] loss=0.715
Predicted string: ohlol, Epoch [5/100] loss=0.537
Predicted string: ohlol, Epoch [6/100] loss=0.386
Predicted string: ohlol, Epoch [7/100] loss=0.265
Predicted string: ohlol, Epoch [8/100] loss=0.175
Predicted string: ohlol, Epoch [9/100] loss=0.114
Predicted string: ohlol, Epoch [10/100] loss=0.074
Predicted string: ohlol, Epoch [11/100] loss=0.050
Predicted string: ohlol, Epoch [12/100] loss=0.035
Predicted string: ohlol, Epoch [13/100] loss=0.025
Predicted string: ohlol, Epoch [14/100] loss=0.019
Predicted string: ohlol, Epoch [15/100] loss=0.014
Predicted string: ohlol, Epoch [16/100] loss=0.011
Predicted string: ohlol, Epoch [17/100] loss=0.009
Predicted string: ohlol, Epoch [18/100] loss=0.007
Predicted string: ohlol, Epoch [19/100] loss=0.006
Predicted string: ohlol, Epoch [20/100] loss=0.005
Predicted string: ohlol, Epoch [21/100] loss=0.004
Predicted string: ohlol, Epoch [22/100] loss=0.003
Predicted string: ohlol, Epoch [23/100] loss=0.003
Predicted string: ohlol, Epoch [24/100] loss=0.002
Predicted string: ohlol, Epoch [25/100] loss=0.002
Predicted string: ohlol, Epoch [26/100] loss=0.002
Predicted string: ohlol, Epoch [27/100] loss=0.002
Predicted string: ohlol, Epoch [28/100] loss=0.001
Predicted string: ohlol, Epoch [29/100] loss=0.001
Predicted string: ohlol, Epoch [30/100] loss=0.001
Predicted string: ohlol, Epoch [31/100] loss=0.001
Predicted string: ohlol, Epoch [32/100] loss=0.001
Predicted string: ohlol, Epoch [33/100] loss=0.001
Predicted string: ohlol, Epoch [34/100] loss=0.001
Predicted string: ohlol, Epoch [35/100] loss=0.001
Predicted string: ohlol, Epoch [36/100] loss=0.001
Predicted string: ohlol, Epoch [37/100] loss=0.001
Predicted string: ohlol, Epoch [38/100] loss=0.001
Predicted string: ohlol, Epoch [39/100] loss=0.001
Predicted string: ohlol, Epoch [40/100] loss=0.001
Predicted string: ohlol, Epoch [41/100] loss=0.001
Predicted string: ohlol, Epoch [42/100] loss=0.001
Predicted string: ohlol, Epoch [43/100] loss=0.001
Predicted string: ohlol, Epoch [44/100] loss=0.001
Predicted string: ohlol, Epoch [45/100] loss=0.001
Predicted string: ohlol, Epoch [46/100] loss=0.001
Predicted string: ohlol, Epoch [47/100] loss=0.001
Predicted string: ohlol, Epoch [48/100] loss=0.001
Predicted string: ohlol, Epoch [49/100] loss=0.000
Predicted string: ohlol, Epoch [50/100] loss=0.000
Predicted string: ohlol, Epoch [51/100] loss=0.000
Predicted string: ohlol, Epoch [52/100] loss=0.000
Predicted string: ohlol, Epoch [53/100] loss=0.000
Predicted string: ohlol, Epoch [54/100] loss=0.000
Predicted string: ohlol, Epoch [55/100] loss=0.000
Predicted string: ohlol, Epoch [56/100] loss=0.000
Predicted string: ohlol, Epoch [57/100] loss=0.000
Predicted string: ohlol, Epoch [58/100] loss=0.000
Predicted string: ohlol, Epoch [59/100] loss=0.000
Predicted string: ohlol, Epoch [60/100] loss=0.000
Predicted string: ohlol, Epoch [61/100] loss=0.000
Predicted string: ohlol, Epoch [62/100] loss=0.000
Predicted string: ohlol, Epoch [63/100] loss=0.000
Predicted string: ohlol, Epoch [64/100] loss=0.000
Predicted string: ohlol, Epoch [65/100] loss=0.000
Predicted string: ohlol, Epoch [66/100] loss=0.000
Predicted string: ohlol, Epoch [67/100] loss=0.000
Predicted string: ohlol, Epoch [68/100] loss=0.000
Predicted string: ohlol, Epoch [69/100] loss=0.000
Predicted string: ohlol, Epoch [70/100] loss=0.000
Predicted string: ohlol, Epoch [71/100] loss=0.000
Predicted string: ohlol, Epoch [72/100] loss=0.000
Predicted string: ohlol, Epoch [73/100] loss=0.000
Predicted string: ohlol, Epoch [74/100] loss=0.000
Predicted string: ohlol, Epoch [75/100] loss=0.000
Predicted string: ohlol, Epoch [76/100] loss=0.000
Predicted string: ohlol, Epoch [77/100] loss=0.000
Predicted string: ohlol, Epoch [78/100] loss=0.000
Predicted string: ohlol, Epoch [79/100] loss=0.000
Predicted string: ohlol, Epoch [80/100] loss=0.000
Predicted string: ohlol, Epoch [81/100] loss=0.000
Predicted string: ohlol, Epoch [82/100] loss=0.000
Predicted string: ohlol, Epoch [83/100] loss=0.000
Predicted string: ohlol, Epoch [84/100] loss=0.000
Predicted string: ohlol, Epoch [85/100] loss=0.000
Predicted string: ohlol, Epoch [86/100] loss=0.000
Predicted string: ohlol, Epoch [87/100] loss=0.000
Predicted string: ohlol, Epoch [88/100] loss=0.000
Predicted string: ohlol, Epoch [89/100] loss=0.000
Predicted string: ohlol, Epoch [90/100] loss=0.000
Predicted string: ohlol, Epoch [91/100] loss=0.000
Predicted string: ohlol, Epoch [92/100] loss=0.000
Predicted string: ohlol, Epoch [93/100] loss=0.000
Predicted string: ohlol, Epoch [94/100] loss=0.000
Predicted string: ohlol, Epoch [95/100] loss=0.000
Predicted string: ohlol, Epoch [96/100] loss=0.000
Predicted string: ohlol, Epoch [97/100] loss=0.000
Predicted string: ohlol, Epoch [98/100] loss=0.000
Predicted string: ohlol, Epoch [99/100] loss=0.000
Predicted string: ohlol, Epoch [100/100] loss=0.000

[Figure 1: training loss vs. epoch]
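
Once the loss has converged, the trained model can be queried without gradient tracking. A minimal inference sketch, reusing net, inputs, and idx2char from the script above (torch.no_grad() disables autograd for evaluation):

with torch.no_grad():
    logits = net(inputs)         # (batch * seq_len, num_class) = (5, 4)
    pred = logits.argmax(dim=1)  # most likely class index per step
    print(''.join(idx2char[i] for i in pred.tolist()))  # expected: 'ohlol'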

2. References

[1] https://www.bilibili.com/video/BV1Y7411d7Ys?p=12
