# PyTorch LSTM 时序预测代码 (time-series forecasting with a PyTorch LSTM)

import datetime
import os

import pymysql
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

torch.manual_seed(1)  # fixed seed so weight init / training is reproducible
train_window = 6  # use the previous 6 values to predict the next one
company_data = list()  # NOTE(review): not referenced by any visible code

# Connection settings may be overridden through environment variables so the
# password need not live in source control. The fallbacks are the original
# hard-coded values, so behaviour is unchanged when the variables are unset.
conn = pymysql.connect(
    host=os.environ.get("MYSQL_HOST", "127.0.0.1"),
    user=os.environ.get("MYSQL_USER", "root"),
    password=os.environ.get("MYSQL_PASSWORD", "123456"),
    database=os.environ.get("MYSQL_DATABASE", "test"),
)
cursor = conn.cursor()

#---------------------模型定义 start-------------------------
class LSTMPred(nn.Module):
    """Single-layer LSTM followed by a linear head emitting one value per step.

    The recurrent state lives on the instance (``self.hidden``) and is
    threaded through successive ``forward`` calls. Call ``init_hidden`` and
    assign the result to ``self.hidden`` to reset it.
    """

    def __init__(self, input_size, hidden_dim):
        super().__init__()
        self.input_dim = input_size
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(input_size, hidden_dim)
        # Project every hidden vector down to a single scalar prediction.
        self.hidden2out = nn.Linear(hidden_dim, 1)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        """Return a fresh all-zero ``(h_0, c_0)`` pair of shape (1, 1, hidden_dim)."""
        state_shape = (1, 1, self.hidden_dim)
        h_0 = Variable(torch.zeros(*state_shape))
        c_0 = Variable(torch.zeros(*state_shape))
        return h_0, c_0

    def forward(self, seq):
        """Run ``seq`` through the LSTM and return per-step outputs.

        ``seq`` is reshaped to ``(len(seq), 1, -1)``, i.e. ``len(seq)`` time
        steps with a batch of one. Returns a tensor of shape ``(len(seq), 1)``
        and stores the final LSTM state in ``self.hidden``.
        """
        steps = len(seq)
        lstm_in = seq.view(steps, 1, -1)
        lstm_out, self.hidden = self.lstm(lstm_in, self.hidden)
        flat = lstm_out.view(steps, -1)
        return self.hidden2out(flat)
#------------------------自定义函数---------------------------
def create_inout_sequences(input_data, tw):
    """Slice ``input_data`` into ``(window, next_item)`` training pairs.

    For every start index ``i`` with ``i + tw < len(input_data)`` the pair is
    ``(input_data[i:i+tw], input_data[i+tw:i+tw+1])``. The label is a
    one-element slice of the same sequence type, not a bare element.
    Returns an empty list when ``input_data`` has ``tw`` or fewer items.
    """
    return [
        (input_data[start:start + tw], input_data[start + tw:start + tw + 1])
        for start in range(len(input_data) - tw)
    ]
def ToVariable(x):
    """Build a float32 tensor from ``x`` via ``torch.FloatTensor``, wrapped in ``Variable``."""
    return Variable(torch.FloatTensor(x))

#----------------------------sql--------------------------
# `today` was never defined in the original script (NameError). Every query
# window below is anchored to the current local time.
# NOTE(review): confirm whether `createtime` is stored in local time or UTC.
today = datetime.datetime.now()

endtime = '{:%Y-%m-%d %H:%M:%S}'.format(today)
# Training uses the last 12 hours of data; testing uses the last 8 hours.
starttime_train = '{:%Y-%m-%d %H:%M:%S}'.format(today + datetime.timedelta(hours=-12))
starttime_test = '{:%Y-%m-%d %H:%M:%S}'.format(today + datetime.timedelta(hours=-8))

# The time bounds are passed as query parameters instead of being
# concatenated into the SQL text; pymysql quotes/escapes them.
train_sql = (
    "SELECT value FROM tablename t where createtime > %s"
    " and createtime <= %s order by createtime asc"
)
test_sql = (
    "SELECT value FROM tablename where createtime > %s"
    " and createtime <= %s order by createtime asc"
)


def _fetch_values(sql, params):
    """Execute ``sql`` with ``params`` and return column 0 of every row as float.

    ``float()`` normalises driver-specific numeric types (e.g. ``Decimal``
    for DECIMAL columns) so downstream tensor construction receives plain
    Python floats.
    """
    cursor.execute(sql, params)
    return [float(row[0]) for row in cursor.fetchall()]


train_data_val = _fetch_values(train_sql, (starttime_train, endtime))
test_data_val = _fetch_values(test_sql, (starttime_test, endtime))
train_data = create_inout_sequences(train_data_val, train_window)
test_data = create_inout_sequences(test_data_val, train_window)
#---------------------模型训练-------------------------
model = LSTMPred(1, 6)  # 1 input feature per time step, 6 hidden units
loss_function = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
for epoch in range(10):
    for seq, outs in train_data:
        seq = ToVariable(seq)
        outs = ToVariable(outs)  # one-element target: the value after the window
        optimizer.zero_grad()
        # Each window is an independent sample, so start from a zero state
        # rather than carrying the previous window's state into this one.
        model.hidden = model.init_hidden()
        modout = model(seq)
        # Score only the output after the last step of the window. This is the
        # same element the inference loop reads with `[-1]`. Comparing the
        # full (len(seq), 1) output against the (1,) target instead
        # broadcast the loss over every step, i.e. trained each step to
        # emit the next value.
        loss = loss_function(modout[-1], outs)
        loss.backward()
        optimizer.step()
#--------------------------测试-------------------------------------
predDat = []
# Inference only: no autograd graph is needed.
with torch.no_grad():
    for seq, _true_val in test_data:
        # Reset the recurrent state per window, as during training. Otherwise
        # the state left by one window leaks into the next prediction.
        model.hidden = model.init_hidden()
        # The last per-step output is the prediction for the value after the window.
        predDat.append(model(ToVariable(seq))[-1].item())

# Parameterized insert: the original concatenation omitted the opening quote
# around the timestamp and produced invalid SQL.
insert_sql = "replace into trend values(%s,null,%s)"
for inx, val in enumerate(predDat):
    # Predictions are stamped one hour apart, starting one hour after `today`.
    d = '{:%Y-%m-%d %H:%M}'.format(today + datetime.timedelta(hours=inx + 1))
    cursor.execute(insert_sql, (d, val))
conn.commit()
print(predDat)

# Delete `trend` rows whose createtime is at or before `endtime`. The
# predictions inserted earlier in this script are stamped at least one hour
# after `today`, so presumably only rows from earlier runs are removed.
# (Assumes the first column written by the insert is `createtime`; confirm.)
cursor.execute("delete from trend where createtime <= %s", (endtime,))
conn.commit()

# Release database resources.
cursor.close()
conn.close()

# 你可能感兴趣的: 机器学习, pytorch, lstm, 深度学习 (related tags: machine learning, PyTorch, LSTM, deep learning)