B站up主“刘二大人”视频 笔记
本文章是该视频的一部分,该部分的案例代码使用RNN做一个简单的实验,其余部分见作者的其他文章。
一、什么是循环神经网络
循环神经网络出现于20世纪 80年代,最近由于网络设计的推进和图形处理单元上计算能力的提升,循环神经网络变得越来越流行。这种网络尤其是对序列数据非常有用,因为每个神经元或者单元能用它的内部存储来保存之前输入的相关信息。在语言的案例中,“I had washed my house”这句话的意思与“I had my house washed”大不相同。这就能让网络获取对该表达更深的理解。
注意到这点很重要,因为当阅读一个句子甚至是一个人时,你就是要从它之前的单词中提出每个词的语境。
一个循环神经网络里有很多个环,这些环能允许带着信息通过神经元,同时在输入中读取它们。
二、循环神经网络能干什么
RNN 有很多应用。一个不错的应用是与自然语言处理(NLP)的合作。网上已经有很多人证明了 RNN,他们创造出了令人惊讶的模型,这些模型能表示一种语言模型。这些语言模型能采纳像莎士比亚的诗歌这样的大量输入,并在训练这些模型后生成它们自己的莎士比亚式的诗歌,而且这些诗歌很难与原作区分开来。
另一个让人惊喜的 RNN 应用是机器翻译。这种方法很有趣,因为它需要同时训练两个 RNN。在这些网络中,输入的是成对的不同语言的句子。例如,你能给这个网络输入意思相同的一对英法两种语言的句子,其中英语是源语言,法语作为翻译语言。有了足够的训练后,你给这个网络一个英语句子,它就能把它翻译成法语!这个模型被称为序列到序列模型(Sequence to Sequences model )或者编码-解码模型(Encoder- Decoder model)。
三、RNN使用案例,代码如下:
import torch
num_class = 4
input_size = 4
hidden_size = 8
embedding_size = 10
num_layers = 2
batch_size = 1
seq_len = 5
idx2char = ['e', 'h', 'l', 'o']
x_data = [[1, 0, 2, 2, 3]] # hello
y_data = [3, 1, 2, 3, 2] # ohlol
inputs = torch.LongTensor(x_data)
labels = torch.LongTensor(y_data)
class Model(torch.nn.Module):
def __init__(self):
super(Model, self).__init__()
self.emb = torch.nn.Embedding(input_size, embedding_size) # 嵌入层
self.rnn = torch.nn.RNN(input_size=embedding_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=True)
self.fc = torch.nn.Linear(hidden_size, num_class)
def forward(self, x):
hidden = torch.zeros(num_layers, x.size(0), hidden_size) # 构造h0
x = self.emb(x) # 把长整型转变成嵌入层稠密的向量模式
x, _ = self.rnn(x, hidden)
x = self.fc(x)
return x.view(-1, num_class)
net = Model()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)
for epoch in range(15):
optimizer.zero_grad() # 优化器归零
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward() # 反向传播
optimizer.step() # 优化器更新
_, idx = outputs.max(dim=1)
idx = idx.data.numpy()
print('Predicted: ', ''.join([idx2char[x] for x in idx]), end='')
print(', Epoch [%d/15] loss=%.3f ' % (epoch + 1, loss.item()))
运行结果如下:
视频截图如下:
独热向量(one-hot)缺点:1、向量维度太高;2、向量稀疏;3、向量是硬编码(不是学习出来的)