from models import model
from dataset import Mydata
import torch
import warnings
warnings.filterwarnings("ignore")
# Ask for the dataset location and how many passes to make over it.
path = input("请输入数据集路径:")
epochs = input("请输入训练次数(整数):")

mydata = Mydata(path=path)
dataloader = torch.utils.data.DataLoader(dataset=mydata, batch_size=10, shuffle=False)

# Train on the first GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Plain SGD on a cross-entropy objective over the class logits.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

# Epochs are numbered from 1 in the log output.
for epoch in range(1, int(epochs) + 1):
    model.train()
    for step, (inputs, labels) in enumerate(dataloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        # Log every 50th batch.
        if step % 50 == 0:
            print('第{}个epoch: loss: {}'.format(epoch, loss.item()))

# Persist only the learned weights (state_dict), not the module object.
torch.save(model.state_dict(), "mymodel.pth")

# To reload later:
#   model = ModelClass()  # construct the network first
#   model.load_state_dict(torch.load("yourpath.pth"))
下面是模型代码,models.py
import torch
from torch import nn, einsum
import torch.nn.functional as F
import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torchsummary import summary
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: width of the input/output features (last axis of ``x``).
        heads: number of attention heads.
        dim_head: width of each head; q, k and v each have ``heads * dim_head``
            features.
        dropout: dropout rate applied after the output projection.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        # 1/sqrt(dim_head) scaling of the dot products.
        self.scale = dim_head ** -0.5
        # Single projection producing q, k and v concatenated: dim -> 3 * inner_dim.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, mask=None):
        """Attend over the sequence axis of ``x``.

        Args:
            x: tensor of shape (b, n, dim).
            mask: optional boolean tensor whose ``flatten(1)`` has shape
                (b, n - 1); a True is prepended for position 0 (the CLS slot).
                Score entries whose query or key position is False are filled
                with a large negative value before the softmax.

        Returns:
            Tensor of shape (b, n, dim).
        """
        b, n, _ = x.shape
        h = self.heads
        # (b, n, h*d) -> (b, h, n, d) for each of q, k, v.
        q, k, v = (
            t.reshape(b, n, h, -1).transpose(1, 2)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )
        # Attention scores, shape (b, h, n, n).
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value=True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            # Pairwise (query, key) mask with a singleton head axis so it
            # broadcasts against (b, h, n, n). The former (b, n, n) form only
            # broadcast when b == heads or b == 1 and crashed otherwise.
            mask = mask[:, None, :, None] & mask[:, None, None, :]
            dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)
        attn = dots.softmax(dim=-1)
        # Weighted sum of values, shape (b, h, n, d).
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        # (b, h, n, d) -> (b, n, h*d), then project back to dim.
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
class PreNorm(nn.Module):
    """Normalise the last axis with LayerNorm, then call the wrapped module.

    Any keyword arguments (e.g. ``mask``) are passed through to ``fn``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
class Residual(nn.Module):
    """Skip connection around a module: returns ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        shortcut = x
        return self.fn(x, **kwargs) + shortcut
class FeedForward(nn.Module):
    """Position-wise MLP applied to the last axis.

    Linear(dim -> hidden_dim) -> GELU -> Dropout -> Linear(hidden_dim -> dim) -> Dropout.
    The layers live in ``self.net`` (indices 0..4), which fixes the state_dict keys.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class Transformer(nn.Module):
    """Sequence classifier built from pre-norm residual attention blocks.

    Each input step is projected to ``dim`` features, a learnable CLS token is
    prepended, ``depth`` (attention, feed-forward) pairs are applied, and the
    CLS position is mapped to ``num_classes`` logits.

    Args:
        dim: feature width used throughout the encoder.
        depth: number of (attention, feed-forward) layer pairs.
        heads: attention heads per layer.
        dim_head: width of each attention head.
        mlp_dim: hidden width of each feed-forward block.
        dropout: dropout rate inside attention output and feed-forward.
        num_features: size of the last input axis (default 8).
        num_classes: number of output logits (default 8).
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout,
                 num_features=8, num_classes=8):
        super(Transformer, self).__init__()
        # Per-step input projection: num_features -> dim.
        self.fn = nn.Linear(num_features, dim, bias=False)
        # Learnable token prepended to every sequence; its output is classified.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))),
            ]))
        # Classification head sized from `dim`. It was hard-coded to 16,
        # which only worked when dim == 16.
        self.out = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes, bias=False),
        )

    def liners(self, x, a, b):
        """Apply a bias-free Linear(a -> b) to ``x``.

        NOTE(review): this creates a NEW randomly initialised layer on every
        call. It is not registered as a submodule, so it is never trained,
        saved or moved to the model's device. forward() does not use it.
        Kept only for backward compatibility.
        """
        cc = nn.Linear(a, b, bias=False)(x)
        return cc

    def forward(self, x, mask=None):
        """Classify a batch of sequences.

        Args:
            x: tensor of shape (batch, seq_len, num_features).
            mask: optional mask forwarded to every Attention layer.

        Returns:
            Logits of shape (batch, num_classes), computed from the CLS position.
        """
        x = self.fn(x)
        cls_token = self.cls_token.repeat(x.shape[0], 1, 1)
        # Prepend CLS along the sequence axis: (batch, seq_len + 1, dim).
        x = torch.cat((cls_token, x), dim=1)
        for attn, ff in self.layers:
            x = attn(x, mask=mask)
            x = ff(x)
        y = x[:, 0, :]
        return self.out(y)
# Module-level model instance; the training and prediction scripts both do
# `from models import model`.
model = Transformer(dim=16, depth=10, heads=8, dim_head=25, mlp_dim=16, dropout=.2)

if __name__ == '__main__':
    # torchsummary.summary defaults to device="cuda" and builds CUDA inputs,
    # but `model` is still on the CPU here, so request CPU explicitly.
    # summary() prints its own table and returns None, so its result is not printed.
    summary(model, input_size=(60, 8), batch_size=10, device='cpu')
下面是数据集代码,dataset.py
import numpy as np
from torch import nn
import torch
from torch.utils.data import DataLoader
class Mydata(torch.utils.data.Dataset):
    """Sliding-window dataset over a comma-separated sequence of planet names.

    The file at ``path`` (UTF-8) holds names separated by ',' or the full-width
    '，'. Each recognised name is encoded as a digit 1..8, and unrecognised
    tokens are skipped. The digit string is also written to ``font_num.txt``
    in the current working directory, as the original implementation did.
    Every run of 61 consecutive codes forms one sample: the first 60 are
    returned one-hot encoded, and the 61st, minus 1, is the label.
    """

    # Planet name -> code 1..8.
    _CODES = {
        '金星': 1, '木星': 2, '水星': 3, '火星': 4,
        '土星': 5, '天王星': 6, '海王星': 7, '开普勒': 8,
    }
    # Number of input steps per sample; the step after them is the label.
    _SEQ_LEN = 60

    def __init__(self, path='data.txt'):
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        # Accept full-width Chinese commas as separators as well as ASCII ones.
        text = text.replace('，', ',')
        tokens = (tok.strip() for tok in text.strip().split(','))
        codes = ''.join(str(self._CODES[t]) for t in tokens if t in self._CODES)
        # Preserve the original side effect of writing the encoded sequence;
        # the windows below are built from memory instead of re-reading it.
        with open('font_num.txt', 'w') as f2:
            f2.write(codes)
        window = self._SEQ_LEN + 1
        # Every complete window, including the one ending at the last code.
        # The previous range(61, len) stopped one window short.
        self.data = [[int(ch) for ch in codes[start:start + window]]
                     for start in range(len(codes) - window + 1)]
        # Row k is the one-hot vector for class index k.
        self.eye = np.eye(8).astype('float32')

    def __getitem__(self, item):
        """Return (inputs, label): a float32 (60, 8) one-hot array and an int in 0..7."""
        dd = self.data[item][:60]
        ll = self.data[item][60] - 1
        dd = np.array(dd) - 1
        dd = self.eye[dd]
        return dd, ll

    def __len__(self):
        return len(self.data)
if __name__ == '__main__':
    # Sanity check: load data2.txt and print each batch's shape and labels.
    dataset = Mydata(path='data2.txt')
    loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=10, shuffle=False)
    for batch, labels in loader:
        print(batch.shape, labels)
下面是预测代码,pred.py
import numpy as np
from models import model
import torch
import torch.nn.functional as F
import matplotlib.pylab as plt
# Put the shared model into inference mode (disables dropout) and load the
# weights written by the training script. map_location='cpu' lets a
# checkpoint saved from a CUDA run load on a CPU-only machine; predict()
# below runs entirely on the CPU.
model.eval()
model.load_state_dict(torch.load('mymodel.pth', map_location='cpu'))
# Row k of the 8x8 identity is the one-hot encoding of class index k.
eye = torch.eye(8)
def predict(inp):
    """Return the model's three most likely next classes for one sequence.

    Args:
        inp: sequence of integer codes in 1..8 (e.g. the numpy array built by
            the GUI). It is not modified.

    Returns:
        A list of three ``(probability, class_index)`` tuples, highest
        probability first. Probabilities are rounded to 3 decimals and
        class_index is 0-based.
    """
    # Shift to 0-based indices on a copy; the old `inp -= 1` mutated the
    # caller's array in place and failed for plain lists.
    idx = torch.as_tensor(np.asarray(inp, dtype=np.int64) - 1)
    x = torch.unsqueeze(eye[idx], dim=0)  # (1, seq_len, 8)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        probs = F.softmax(model(x), dim=-1)
    values, indices = torch.sort(probs, dim=-1, descending=True)
    values = torch.squeeze(values).numpy()
    indices = torch.squeeze(indices).numpy()
    return [(round(values[k], 3), indices[k]) for k in range(3)]
if __name__ == '__main__':
    # Smoke test: 60 copies of code 1 (the first planet).
    sample = np.ones(60, dtype='int32')
    print(predict(sample))
下面是程序的主入口main.py,直接运行该代码即可弹出界面,实现预测
import numpy as np
from PyQt5.QtWidgets import *
from PyQt5 import QtWidgets
from PyQt5.QtGui import QPixmap,QImage
from PyQt5 import QtGui
from untitled import Ui_Form
from PyQt5.QtWidgets import QFileDialog
import cv2
import sys
from pred import predict
import warnings
warnings.filterwarnings("ignore")
class My(QtWidgets.QWidget, Ui_Form):
    """Main window: edit a comma-separated planet sequence and show the
    model's top-3 next-step predictions.

    textEdit holds the input sequence. lineEdit / lineEdit_2 / lineEdit_3 show
    the top-3 probabilities, and lineEdit_4 / lineEdit_5 / lineEdit_6 show the
    matching planet names.
    """

    # Value verify() returns when the input is rejected (kept for compatibility).
    _INVALID = '0'

    def __init__(self):
        super(My, self).__init__()
        self.setupUi(self)
        self.use_palette()
        self.pushButton.clicked.connect(self.start)
        self.pushButton_2.clicked.connect(self.continues)
        # Buttons 3..10 each append one planet name to the input box.
        self.pushButton_3.clicked.connect(self.q3)
        self.pushButton_4.clicked.connect(self.q4)
        self.pushButton_5.clicked.connect(self.q5)
        self.pushButton_6.clicked.connect(self.q6)
        self.pushButton_7.clicked.connect(self.q7)
        self.pushButton_8.clicked.connect(self.q8)
        self.pushButton_9.clicked.connect(self.q9)
        self.pushButton_10.clicked.connect(self.q10)
        # Pre-fill a 60-name example sequence.
        self.textEdit.insertPlainText('金星,木星,水星,火星,土星,天王星,海王星,开普勒,金星,木星,水星,火星,土星,天王星,海王星,开普勒,金星,木星,水星,火星,土星,天王星,海王星,开普勒,'
                                      '金星,木星,水星,火星,土星,天王星,海王星,开普勒,金星,木星,水星,火星,土星,天王星,海王星,开普勒,金星,木星,水星,火星,土星,天王星,海王星,开普勒,'
                                      '金星,木星,水星,火星,土星,天王星,海王星,开普勒,金星,木星,水星,火星')
        # Class index -> planet name; position i corresponds to code i + 1.
        self.name = ['金星', '木星', '水星', '火星', '土星', '天王星', '海王星', '开普勒']

    def use_palette(self):
        """Set the window title and tile 123.jpg as the window background."""
        self.setWindowTitle("时序预测")
        window_pale = QtGui.QPalette()
        window_pale.setBrush(self.backgroundRole(), QtGui.QBrush(QtGui.QPixmap("123.jpg")))
        self.setPalette(window_pale)

    @staticmethod
    def _tokens(text):
        """Split text-box contents into stripped, non-empty tokens.

        Newlines are removed, and full-width commas '，' are treated like ','.
        """
        text = text.replace("\n", "").replace('，', ',')
        return [t.strip() for t in text.split(',') if t.strip()]

    def _show(self, outcom):
        """Display predict()'s top-3 (probability, class_index) pairs."""
        prob_edits = (self.lineEdit, self.lineEdit_2, self.lineEdit_3)
        name_edits = (self.lineEdit_4, self.lineEdit_5, self.lineEdit_6)
        for (prob, idx), prob_edit, name_edit in zip(outcom, prob_edits, name_edits):
            prob_edit.setText(str(prob))
            name_edit.setText(self.name[idx])

    def start(self):
        """Predict from the current contents of the input box."""
        datas = self.verify(self._tokens(self.textEdit.toPlainText()))
        # verify() returns a string sentinel on failure. Comparing the ndarray
        # with '0' (as before) yields an array on NumPy >= 1.25, and `if` on
        # that array raises ValueError.
        if isinstance(datas, str):
            return
        self._show(predict(datas))

    def continues(self):
        """Slide the window forward by one step and predict again.

        Drops the first name, appends the current top-1 prediction
        (lineEdit_4), writes the new sequence back to the input box and
        re-runs the prediction.
        """
        tokens = self._tokens(self.textEdit.toPlainText())[1:]
        tokens.append(self.lineEdit_4.text())
        self.textEdit.setText(",".join(tokens))
        datas = self.verify(tokens)
        if isinstance(datas, str):
            return
        self._show(predict(datas))

    def verify(self, data):
        """Encode planet names as codes 1..8 and require exactly 60 of them.

        Args:
            data: iterable of name strings; unknown names are ignored.

        Returns:
            A numpy int array of codes. If the number of recognised names is
            not 60, a warning dialog is shown and '0' is returned instead.
        """
        codes = {name: i + 1 for i, name in enumerate(self.name)}
        data2 = [codes[token] for token in data if token in codes]
        if len(data2) != 60:
            # Report the count that was actually checked (recognised names),
            # not the raw token count.
            msg_box = QMessageBox(QMessageBox.Warning, '警告', '数据有{}个,必须输入60个'.format(len(data2)))
            msg_box.exec_()
            return self._INVALID
        return np.array(data2)

    def _append_name(self, index):
        """Append self.name[index] to the input box, comma-separated if it is non-empty."""
        sep = "," if self.textEdit.toPlainText() else ""
        self.textEdit.insertPlainText(sep + self.name[index])

    def q3(self):
        self._append_name(0)

    def q4(self):
        self._append_name(1)

    def q5(self):
        self._append_name(2)

    def q6(self):
        self._append_name(3)

    def q7(self):
        self._append_name(4)

    def q8(self):
        self._append_name(5)

    def q9(self):
        self._append_name(6)

    def q10(self):
        self._append_name(7)
if __name__ == '__main__':
    # Create the Qt application, show the main window and run the event loop
    # until the window is closed; its return code becomes the exit status.
    app = QtWidgets.QApplication(sys.argv)
    window = My()
    window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)
下面是untitled.py代码
# -*- coding: utf-8 -*-
# Form implementation generated from reading ui file 'untitled.ui'
#
# Created by: PyQt5 UI code generator 5.15.4
#
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again. Do not edit this file unless you know what you are doing.
from PyQt5 import QtCore, QtGui, QtWidgets
class Ui_Form(object):
    """Qt Designer layout for the prediction window (generated by pyuic5).

    setupUi() creates and positions the widgets on ``Form``; retranslateUi()
    sets their display texts. Regenerate from untitled.ui rather than editing
    this class by hand.
    """

    def setupUi(self, Form):
        """Create all child widgets of ``Form`` and resize it to 677x564."""
        Form.setObjectName("Form")
        Form.resize(677, 564)
        # Multi-line input box for the comma-separated sequence.
        self.textEdit = QtWidgets.QTextEdit(Form)
        self.textEdit.setGeometry(QtCore.QRect(40, 80, 351, 311))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(14)
        font.setBold(True)
        font.setWeight(75)
        self.textEdit.setFont(font)
        self.textEdit.setStyleSheet("background-color: qconicalgradient(cx:0, cy:0, angle:135, stop:0 rgba(255, 255, 0, 69), stop:0.375 rgba(255, 255, 0, 69), stop:0.423533 rgba(251, 255, 0, 145), stop:0.45 rgba(247, 255, 0, 208), stop:0.477581 rgba(255, 244, 71, 130), stop:0.518717 rgba(255, 218, 71, 130), stop:0.55 rgba(255, 255, 0, 255), stop:0.57754 rgba(255, 203, 0, 130), stop:0.625 rgba(255, 255, 0, 69), stop:1 rgba(255, 255, 0, 69));")
        self.textEdit.setObjectName("textEdit")
        # Heading above the input box.
        self.label = QtWidgets.QLabel(Form)
        self.label.setGeometry(QtCore.QRect(40, 30, 161, 31))
        font = QtGui.QFont()
        font.setFamily("Agency FB")
        font.setPointSize(18)
        font.setBold(True)
        font.setWeight(75)
        self.label.setFont(font)
        self.label.setObjectName("label")
        # Heading above the predicted-name column.
        self.label_2 = QtWidgets.QLabel(Form)
        self.label_2.setGeometry(QtCore.QRect(520, 150, 111, 31))
        font = QtGui.QFont()
        font.setFamily("Agency FB")
        font.setPointSize(18)
        font.setBold(True)
        font.setWeight(75)
        self.label_2.setFont(font)
        self.label_2.setObjectName("label_2")
        # Re-predict (pushButton) and continue-predicting (pushButton_2) buttons.
        self.pushButton = QtWidgets.QPushButton(Form)
        self.pushButton.setGeometry(QtCore.QRect(170, 480, 91, 41))
        self.pushButton.setObjectName("pushButton")
        self.pushButton_2 = QtWidgets.QPushButton(Form)
        self.pushButton_2.setGeometry(QtCore.QRect(400, 480, 91, 41))
        self.pushButton_2.setObjectName("pushButton_2")
        # Probability column (x=420): lineEdit, lineEdit_2, lineEdit_3.
        self.lineEdit = QtWidgets.QLineEdit(Form)
        self.lineEdit.setGeometry(QtCore.QRect(420, 200, 101, 41))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(20)
        font.setBold(True)
        font.setWeight(75)
        self.lineEdit.setFont(font)
        self.lineEdit.setObjectName("lineEdit")
        self.lineEdit_2 = QtWidgets.QLineEdit(Form)
        self.lineEdit_2.setGeometry(QtCore.QRect(420, 260, 101, 41))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(20)
        font.setBold(True)
        font.setWeight(75)
        self.lineEdit_2.setFont(font)
        self.lineEdit_2.setObjectName("lineEdit_2")
        self.lineEdit_3 = QtWidgets.QLineEdit(Form)
        self.lineEdit_3.setGeometry(QtCore.QRect(420, 320, 101, 41))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(20)
        font.setBold(True)
        font.setWeight(75)
        self.lineEdit_3.setFont(font)
        self.lineEdit_3.setObjectName("lineEdit_3")
        # Heading above the probability column.
        self.label_3 = QtWidgets.QLabel(Form)
        self.label_3.setGeometry(QtCore.QRect(420, 150, 61, 31))
        font = QtGui.QFont()
        font.setFamily("Agency FB")
        font.setPointSize(18)
        font.setBold(True)
        font.setWeight(75)
        self.label_3.setFont(font)
        self.label_3.setObjectName("label_3")
        # Predicted-name column (x=540): lineEdit_4, lineEdit_5, lineEdit_6.
        self.lineEdit_4 = QtWidgets.QLineEdit(Form)
        self.lineEdit_4.setGeometry(QtCore.QRect(540, 200, 101, 41))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(20)
        font.setBold(True)
        font.setWeight(75)
        self.lineEdit_4.setFont(font)
        self.lineEdit_4.setObjectName("lineEdit_4")
        self.lineEdit_5 = QtWidgets.QLineEdit(Form)
        self.lineEdit_5.setGeometry(QtCore.QRect(540, 260, 101, 41))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(20)
        font.setBold(True)
        font.setWeight(75)
        self.lineEdit_5.setFont(font)
        self.lineEdit_5.setObjectName("lineEdit_5")
        self.lineEdit_6 = QtWidgets.QLineEdit(Form)
        self.lineEdit_6.setGeometry(QtCore.QRect(540, 320, 101, 41))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(20)
        font.setBold(True)
        font.setWeight(75)
        self.lineEdit_6.setFont(font)
        self.lineEdit_6.setObjectName("lineEdit_6")
        # Row of eight buttons, one per planet name (texts set in retranslateUi).
        self.pushButton_3 = QtWidgets.QPushButton(Form)
        self.pushButton_3.setGeometry(QtCore.QRect(40, 430, 61, 31))
        self.pushButton_3.setObjectName("pushButton_3")
        self.pushButton_4 = QtWidgets.QPushButton(Form)
        self.pushButton_4.setGeometry(QtCore.QRect(110, 430, 61, 31))
        self.pushButton_4.setObjectName("pushButton_4")
        self.pushButton_5 = QtWidgets.QPushButton(Form)
        self.pushButton_5.setGeometry(QtCore.QRect(180, 430, 61, 31))
        self.pushButton_5.setObjectName("pushButton_5")
        self.pushButton_6 = QtWidgets.QPushButton(Form)
        self.pushButton_6.setGeometry(QtCore.QRect(250, 430, 61, 31))
        self.pushButton_6.setObjectName("pushButton_6")
        self.pushButton_7 = QtWidgets.QPushButton(Form)
        self.pushButton_7.setGeometry(QtCore.QRect(320, 430, 61, 31))
        self.pushButton_7.setObjectName("pushButton_7")
        self.pushButton_8 = QtWidgets.QPushButton(Form)
        self.pushButton_8.setGeometry(QtCore.QRect(390, 430, 61, 31))
        self.pushButton_8.setObjectName("pushButton_8")
        self.pushButton_9 = QtWidgets.QPushButton(Form)
        self.pushButton_9.setGeometry(QtCore.QRect(460, 430, 61, 31))
        self.pushButton_9.setObjectName("pushButton_9")
        self.pushButton_10 = QtWidgets.QPushButton(Form)
        self.pushButton_10.setGeometry(QtCore.QRect(530, 430, 61, 31))
        self.pushButton_10.setObjectName("pushButton_10")
        self.retranslateUi(Form)
        QtCore.QMetaObject.connectSlotsByName(Form)

    def retranslateUi(self, Form):
        """Set the user-visible texts of the window and all widgets."""
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "Form"))
        self.label.setText(_translate("Form", "输入数据:"))
        self.label_2.setText(_translate("Form", "预测值"))
        self.pushButton.setText(_translate("Form", "重新预测"))
        self.pushButton_2.setText(_translate("Form", "继续预测"))
        self.lineEdit.setText(_translate("Form", "0"))
        self.lineEdit_2.setText(_translate("Form", "0"))
        self.lineEdit_3.setText(_translate("Form", "0"))
        self.label_3.setText(_translate("Form", "概率"))
        self.lineEdit_4.setText(_translate("Form", "0"))
        self.lineEdit_5.setText(_translate("Form", "0"))
        self.lineEdit_6.setText(_translate("Form", "0"))
        self.pushButton_3.setText(_translate("Form", "金星"))
        self.pushButton_4.setText(_translate("Form", "木星"))
        self.pushButton_5.setText(_translate("Form", "水星"))
        self.pushButton_6.setText(_translate("Form", "火星"))
        self.pushButton_7.setText(_translate("Form", "土星"))
        self.pushButton_8.setText(_translate("Form", "天王星"))
        self.pushButton_9.setText(_translate("Form", "海王星"))
        self.pushButton_10.setText(_translate("Form", "开普勒"))
欢迎点赞加关注,有问题请下方评论!