rnn简单实现

输入 hello，目标输出 ohlol

定义一些数据

import torch
import numpy
# Hyperparameters for the toy character-level RNN ("hello" -> "ohlol").
batch_size=1  # one sequence per batch
num_layers=1  # a single stacked RNN layer
input_size=4  # one-hot width: the vocabulary has 4 characters (e, h, l, o)
hidden_size=4  # hidden-state size; also used as the number of output classes
seq_len=5  # length of the input/target sequence

进行数据准备

# Vocabulary: index -> character (e=0, h=1, l=2, o=3).
idx2char=['e','h','l','o']
x_data=[1,0,2,2,3]  # input sequence "hello" as character indices
y_data=[3,1,2,3,2]  # target sequence "ohlol" as character indices
# Lookup table: row i is the one-hot vector for character index i.
one_hot_lookup=[[1,0,0,0],
                [0,1,0,0],
                [0,0,1,0],
                [0,0,0,1]]
# Convert x_data to its one-hot encoding (list of 4-element rows).
x_one_hot=[one_hot_lookup[x] for x in x_data]
# RNN input shape is (seq_len, batch_size, input_size).
inputs=torch.Tensor(x_one_hot).view(seq_len,batch_size,input_size)
# labels shape is (seq_len*batch_size,) -- one class index per time step,
# matching the flattened model output fed to CrossEntropyLoss.
labels=torch.LongTensor(y_data)

模型

class Model(torch.nn.Module):
    """Character-level RNN mapping a one-hot sequence to per-step class scores.

    Input : tensor of shape (seq_len, batch_size, input_size).
    Output: tensor of shape (seq_len * batch_size, hidden_size) -- raw
            scores (logits) suitable for CrossEntropyLoss against a flat
            vector of class indices.
    """

    def __init__(self,input_size,hidden_size,batch_size,num_layers=1):
        super(Model,self).__init__()
        # batch_size is kept for backward compatibility with existing callers;
        # forward() now infers the actual batch size from its input instead.
        self.batch_size=batch_size
        self.input_size=input_size
        self.hidden_size=hidden_size
        self.num_layers=num_layers
        # nn.RNN uses tanh as its internal activation by default.
        self.rnn=torch.nn.RNN(input_size=self.input_size,
                              hidden_size=self.hidden_size,
                              num_layers=num_layers)

    def forward(self,input):
        # Initial hidden state h0: zeros of shape (num_layers, batch, hidden).
        # The batch size is read from the input (dim 1) so any batch size
        # works, and the tensor is created with the input's dtype/device so
        # the model also runs on GPU inputs.
        hidden=torch.zeros(self.num_layers,input.size(1),self.hidden_size,
                           dtype=input.dtype,device=input.device)
        out,_=self.rnn(input,hidden)
        # Flatten to (seq_len*batch_size, hidden_size) for CrossEntropyLoss.
        return out.view(-1,self.hidden_size)
    
# Instantiate the model with the hyperparameters defined above.
net=Model(input_size,hidden_size,batch_size,num_layers)

损失函数和优化器

# CrossEntropyLoss combines log-softmax and NLL loss, so the model emits
# raw scores (logits). Adam with lr=0.1 converges quickly on this tiny task.
criterion=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(net.parameters(),lr=0.1)

开始训练

# RNN input : (seq_len, batch_size, input_size)
# RNN output: (seq_len, batch_size, hidden_size), flattened by the model
#             to (seq_len*batch_size, hidden_size)
# labels    : (seq_len*batch_size,) vector of class indices
for epoch in range(15):
    optimizer.zero_grad()
    outputs=net(inputs)
    loss=criterion(outputs,labels)
    loss.backward()
    optimizer.step()
    # Greedy decode: predicted class = argmax over the 4 character scores
    # at each time step.
    _,idx=outputs.max(dim=1)
    idx=idx.data.numpy()
    # Print the predicted string, then the epoch number and loss.
    print('pre:',''.join([idx2char[x] for x in idx]),end='')
    print(',Epoch [%d/15] loss=%.3f' % (epoch+1,loss.item()))

结果

pre: ooooo,Epoch [1/15] loss=1.390
pre: oollo,Epoch [2/15] loss=1.153
pre: ohllh,Epoch [3/15] loss=0.979
pre: ohloh,Epoch [4/15] loss=0.860
pre: ohlol,Epoch [5/15] loss=0.777
pre: ohlol,Epoch [6/15] loss=0.711
pre: ohlol,Epoch [7/15] loss=0.646
pre: ohlol,Epoch [8/15] loss=0.584
pre: ohlol,Epoch [9/15] loss=0.535
pre: ohlol,Epoch [10/15] loss=0.501
pre: ohlol,Epoch [11/15] loss=0.474
pre: ohlol,Epoch [12/15] loss=0.450
pre: ohlol,Epoch [13/15] loss=0.430
pre: ohlol,Epoch [14/15] loss=0.412
pre: ohlol,Epoch [15/15] loss=0.398

你可能感兴趣的:(rnn简单实现)