输入 "hello"，目标输出 "ohlol"
定义一些数据
import torch
import numpy
# Hyper-parameters for the toy character-sequence task.
seq_len = 5      # number of time steps (characters) in the sequence
input_size = 4   # width of each one-hot input vector
hidden_size = 4  # RNN hidden width; the hidden state is used directly as class scores
num_layers = 1   # single-layer RNN
batch_size = 1   # one sequence per batch
进行数据准备
# Vocabulary: position in this list is the class index of the character.
idx2char = ['e', 'h', 'l', 'o']

# "hello" (input) and "ohlol" (target) expressed as vocabulary indices.
x_data = [1, 0, 2, 2, 3]
y_data = [3, 1, 2, 3, 2]

# Lookup table: row i is the one-hot vector for vocabulary index i.
one_hot_lookup = [
    [1 if col == row else 0 for col in range(len(idx2char))]
    for row in range(len(idx2char))
]

# One-hot encode the input sequence, one row per character.
x_one_hot = [one_hot_lookup[char_idx] for char_idx in x_data]

# RNN input is laid out as (seq_len, batch_size, input_size).
inputs = torch.Tensor(x_one_hot)
inputs = inputs.view(seq_len, batch_size, input_size)

# Targets are a 1-D tensor of class indices, one per time step.
labels = torch.tensor(y_data, dtype=torch.long)
模型
class Model(torch.nn.Module):
    """Thin wrapper around ``torch.nn.RNN`` that returns flattened hidden states.

    The RNN output at every time step is returned as one row, so the result
    can be passed straight to a per-step classification loss.

    Args:
        input_size: Width of each input vector.
        hidden_size: Width of the RNN hidden state (and of each output row).
        batch_size: Nominal batch size. Stored for backward compatibility;
            ``forward`` takes the actual batch size from its input.
        num_layers: Number of stacked RNN layers.
    """

    def __init__(self, input_size, hidden_size, batch_size, num_layers=1):
        super().__init__()
        self.batch_size = batch_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # torch.nn.RNN uses tanh as its default non-linearity.
        self.rnn = torch.nn.RNN(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
        )

    def forward(self, input):
        """Run the RNN over ``input`` starting from a zero hidden state.

        Args:
            input: Tensor laid out as (seq_len, batch, input_size).

        Returns:
            Tensor of shape (seq_len * batch, hidden_size).
        """
        # Build h0 from the input itself so that any batch size, dtype and
        # device works, not only the batch size given to the constructor.
        hidden = torch.zeros(
            self.num_layers,
            input.size(1),
            self.hidden_size,
            dtype=input.dtype,
            device=input.device,
        )
        out, _ = self.rnn(input, hidden)
        # Flatten (seq_len, batch, hidden_size) -> (seq_len * batch, hidden_size).
        return out.reshape(-1, self.hidden_size)
# Build the network from the hyper-parameters defined at the top of the script.
net = Model(
    input_size=input_size,
    hidden_size=hidden_size,
    batch_size=batch_size,
    num_layers=num_layers,
)
损失函数和优化器
# Adam updates every parameter of the network with a learning rate of 0.1.
optimizer = torch.optim.Adam(params=net.parameters(), lr=0.1)
# Cross-entropy between the per-step network outputs and the target indices.
criterion = torch.nn.CrossEntropyLoss()
开始训练
# Shapes used below:
#   inputs  -> (seq_len, batch_size, input_size)
#   outputs -> (seq_len * batch_size, hidden_size), one row of scores per step
#   labels  -> (seq_len * batch_size,) class indices
for epoch in range(15):
    # Standard optimisation step: clear gradients, forward, backward, update.
    optimizer.zero_grad()
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    # Highest-scoring class per time step, decoded back into characters.
    _, idx = torch.max(outputs, dim=1)
    idx = idx.detach().numpy()
    predicted = ''.join(idx2char[char_idx] for char_idx in idx)
    progress = ',Epoch [%d/15] loss=%.3f' % (epoch+1, loss.item())
    print('pre:', predicted + progress)
结果
pre: ooooo,Epoch [1/15] loss=1.390
pre: oollo,Epoch [2/15] loss=1.153
pre: ohllh,Epoch [3/15] loss=0.979
pre: ohloh,Epoch [4/15] loss=0.860
pre: ohlol,Epoch [5/15] loss=0.777
pre: ohlol,Epoch [6/15] loss=0.711
pre: ohlol,Epoch [7/15] loss=0.646
pre: ohlol,Epoch [8/15] loss=0.584
pre: ohlol,Epoch [9/15] loss=0.535
pre: ohlol,Epoch [10/15] loss=0.501
pre: ohlol,Epoch [11/15] loss=0.474
pre: ohlol,Epoch [12/15] loss=0.450
pre: ohlol,Epoch [13/15] loss=0.430
pre: ohlol,Epoch [14/15] loss=0.412
pre: ohlol,Epoch [15/15] loss=0.398