ViT Transformer 时序预测

界面：
（图 1：程序界面截图）
输入前 60 个星球的名字，预测第 61 个星球的名字。

代码部分：
数据格式：
（图 2：数据格式截图）数据文件（如 data.txt）是一串用逗号分隔的星球名称，可用名称为：金星、木星、水星、火星、土星、天王星、海王星、开普勒，例如“金星,木星,水星,火星,……”。
训练代码：
train.py

from models import model
from dataset import Mydata
import torch
import warnings
warnings.filterwarnings("ignore")

# Ask for the dataset file and the number of epochs interactively.
path = input("请输入数据集路径:")
epochs = input("请输入训练次数(整数):")
mydata = Mydata(path=path)
dataloader = torch.utils.data.DataLoader(dataset=mydata, batch_size=10, shuffle=False)

# Train on the first GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Plain SGD; CrossEntropyLoss takes the model's raw (batch, 8) logits and the
# dataset's 0-based integer labels.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
Loss = torch.nn.CrossEntropyLoss()

# Epochs are numbered from 1 in the log output.
for epoch in range(1, int(epochs) + 1):
    model.train()
    for step, (inputs, targets) in enumerate(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        loss = Loss(model(inputs), targets)  # forward pass + loss
        loss.backward()                      # back-propagate
        optimizer.step()                     # update the weights

        if step % 50 == 0:
            print('第{}个epoch: loss: {}'.format(epoch, loss.item()))

torch.save(model.state_dict(), "mymodel.pth")
# To reload the weights later:
#   net = Transformer(...)  # same hyper-parameters as models.model
#   net.load_state_dict(torch.load("mymodel.pth"))

模型代码 models.py：

import torch
from torch import nn, einsum
import torch.nn.functional as F
import torch

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torchsummary import summary

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention (ViT style).

    Args:
        dim: feature size of each token (last axis of the input).
        heads: number of attention heads.
        dim_head: feature size of one head; q/k/v are ``heads * dim_head`` wide.
        dropout: dropout probability applied after the output projection.
    """
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        # (b, n, dim) -> (b, n, inner_dim * 3): q, k and v side by side.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        # (b, n, inner_dim) -> (b, n, dim). Attribute names and Sequential
        # indices are state_dict keys; keep them stable for saved checkpoints.
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
        """Attend over the tokens of ``x``.

        Args:
            x: tensor of shape (b, n, dim).
            mask: optional boolean tensor that flattens to (b, n - 1); a True
                is prepended for the leading class token. Query/key pairs in
                which either token is False are excluded from attention.

        Returns:
            Tensor of shape (b, n, dim).
        """
        b, n, _ = x.shape
        h = self.heads
        # Split the fused projection into q, k, v, each (b, n, h*d) -> (b, h, n, d).
        # reshape + permute is exactly einops' rearrange('b n (h d) -> b h n d').
        q, k, v = (t.reshape(b, n, h, -1).permute(0, 2, 1, 3)
                   for t in self.to_qkv(x).chunk(3, dim = -1))

        # Scaled attention scores: (b, h, n, n).
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            # Pairwise mask of shape (b, 1, n, n) that broadcasts over heads.
            # The old (b, n, n) mask was right-aligned against (h, n, n), so it
            # crashed unless b == h or b == 1, and mixed up samples when b == h.
            mask = mask[:, None, :, None] & mask[:, None, None, :]
            dots = dots.masked_fill(~mask, -torch.finfo(dots.dtype).max)

        attn = dots.softmax(dim = -1)

        # Weighted sum of values, then merge heads: (b, h, n, d) -> (b, n, h*d).
        out = torch.einsum('bhij,bhjd->bhid', attn, v)
        out = out.permute(0, 2, 1, 3).reshape(b, n, -1)
        # Project back to (b, n, dim).
        return self.to_out(out)

class PreNorm(nn.Module):
    """Pre-normalisation wrapper: ``fn(LayerNorm(x), **kwargs)``.

    Extra keyword arguments (for example ``mask``) are passed through to the
    wrapped module ``fn``.
    """
    def __init__(self,dim,fn):
        super().__init__()
        # Registration order (norm, then fn) fixes the state_dict key order.
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self,x,**kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

class Residual(nn.Module):
    """Skip connection: returns ``fn(x, **kwargs) + x``."""
    def __init__(self,fn):
        super().__init__()
        self.fn = fn

    def forward(self,x,**kwargs):
        branch = self.fn(x, **kwargs)
        return branch + x

class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Maps (..., dim) to (..., dim) through a hidden layer of width ``hidden_dim``.
    """
    def __init__(self,dim,hidden_dim,dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),   # net.0: expand
            nn.GELU(),                    # net.1
            nn.Dropout(dropout),          # net.2
            nn.Linear(hidden_dim, dim),   # net.3: project back
            nn.Dropout(dropout),          # net.4
        ]
        # Sequential indices become state_dict keys (net.0.*, net.3.*); keep order.
        self.net = nn.Sequential(*layers)

    def forward(self,x):
        out = x
        for layer in self.net:
            out = layer(out)
        return out

class Transformer(nn.Module):
    """Sequence classifier: token vectors -> Transformer encoder -> class logits.

    The input tokens are projected to ``dim``, a learnable class token is
    prepended, ``depth`` pre-norm residual (attention, feed-forward) blocks
    are applied, and the class token's final state is mapped to logits.

    Args:
        dim: model width.
        depth: number of encoder blocks.
        heads, dim_head: forwarded to ``Attention``.
        mlp_dim: hidden width of each ``FeedForward``.
        dropout: dropout probability for attention output and feed-forward.
        in_features: size of each input token (default 8, the one-hot planet id).
        num_classes: number of output logits (default 8).
    """
    def __init__(self,dim,depth,heads,dim_head,mlp_dim,dropout,in_features=8,num_classes=8):
        super(Transformer, self).__init__()

        # Bias-free token embedding: (b, seq, in_features) -> (b, seq, dim).
        self.fn=nn.Linear(in_features,dim,bias=False)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.layers=nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim,Attention(dim,heads=heads,dim_head=dim_head,dropout=dropout))),
                Residual(PreNorm(dim,FeedForward(dim,mlp_dim,dropout=dropout)))
            ]))

        # Classification head. It used to be hard-coded to width 16 (crashing
        # for any other ``dim``) and 8 outputs. With the default configuration
        # the parameter names and shapes are unchanged, so old checkpoints load.
        self.out=nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim,num_classes,bias=False)
        )

    def liners(self,x,a,b):
        """Apply a freshly created ``nn.Linear(a, b, bias=False)`` to ``x``.

        WARNING: a new, randomly initialised layer is built on every call. Its
        weights are never trained, are not registered on this module (so they
        are not in the state_dict) and are created on the default device. Kept
        only for backward compatibility; ``forward`` does not use it.
        """
        return nn.Linear(a,b,bias=False)(x)

    def forward(self,x,mask=None):
        """Map x of shape (b, seq, in_features) to logits of shape (b, num_classes)."""
        x=self.fn(x)
        # One copy of the class token per sample, placed at position 0.
        cls_token=self.cls_token.repeat(x.shape[0],1,1)
        x=torch.cat((cls_token,x),dim=1)

        for attn, ff in self.layers:
            x=attn(x,mask=mask)
            x=ff(x)

        # Read the prediction off the class token's final state.
        return self.out(x[:,0,:])
# Module-level model shared by train.py and pred.py (``from models import model``).
# pred.py loads mymodel.pth into this instance, so these hyper-parameters must
# match the ones the checkpoint was trained with.
model=Transformer(dim=16,depth=10,heads=8,dim_head=25,mlp_dim=16,dropout=.2)
if __name__ == '__main__':
    # torchsummary prints its table itself and returns None, so the old
    # print(summary(...)) also printed "None". The model is still on the CPU
    # here; pass device="cpu" because torchsummary defaults to "cuda" and
    # would feed CUDA inputs to a CPU model on machines that have a GPU.
    summary(model,input_size=(60,8),batch_size=10,device="cpu")


数据集代码 dataset.py：

import numpy as np
from torch import nn
import torch
from torch.utils.data import DataLoader

class Mydata(torch.utils.data.Dataset):
    """Sliding-window dataset over a comma-separated sequence of planet names.

    The file at ``path`` is read as UTF-8 and split on commas (ASCII ``,`` or
    full-width ``，``). Each recognised name is mapped to an id 1..8 and
    unknown tokens are skipped. As before, the id sequence is also written to
    ``font_num.txt`` in the current directory, one digit per id.

    Every run of ``seq_len + 1`` consecutive ids is one sample. The first
    ``seq_len`` ids, one-hot encoded, are the input; the last id (0-based) is
    the label.

    Args:
        path: text file with the planet names.
        seq_len: number of input steps per sample (default 60).
    """

    # Planet names in id order (index i <-> id i + 1); same order as main.py.
    NAMES = ('金星', '木星', '水星', '火星', '土星', '天王星', '海王星', '开普勒')

    def __init__(self,path='data.txt',seq_len=60):
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        # Normalise full-width commas to ASCII. The old replace(',', ',') was a no-op.
        tokens = text.replace('\uff0c', ',').strip().split(',')
        name_to_id = {name: i + 1 for i, name in enumerate(self.NAMES)}
        ids = [name_to_id[t.strip()] for t in tokens if t.strip() in name_to_id]

        # Keep producing the intermediate encoding file for compatibility. The
        # previous version re-read it through a handle that was never closed.
        with open('font_num.txt', 'w', encoding='utf-8') as f2:
            f2.write(''.join(str(i) for i in ids))

        self.seq_len = seq_len
        window = seq_len + 1
        # There are len(ids) - window + 1 windows, including the one ending at
        # the last id (the old range(61, len(...)) dropped it).
        self.data = [ids[start:start + window] for start in range(len(ids) - window + 1)]
        self.eye = np.eye(8).astype('float32')

    def __getitem__(self, item):
        """Return ``(inputs, label)``.

        ``inputs`` is a (seq_len, 8) float32 one-hot array; ``label`` is a 0-based int.
        """
        window = self.data[item]
        inputs = self.eye[np.array(window[:self.seq_len]) - 1]
        label = window[self.seq_len] - 1
        return inputs, label

    def __len__(self):
        return len(self.data)


if __name__ == '__main__':
    # Smoke test: build the dataset from data2.txt and print every batch's
    # input shape together with its labels.
    demo_set = Mydata(path='data2.txt')
    demo_loader = torch.utils.data.DataLoader(dataset=demo_set, batch_size=10, shuffle=False)
    for batch, labels in demo_loader:
        print(batch.shape, labels)

下面是预测代码 pred.py：

import numpy as np
from models import model
import torch
import torch.nn.functional as F
import matplotlib.pylab as plt


# f=open('font_num.txt','r')
# lab=(f.read()).strip()
# lab=[int(i)-1 for i in lab]
#
#
# Put the shared model in inference mode (disables dropout) and load the
# weights written by train.py. train.py may have saved them from a CUDA
# device; map_location='cpu' lets them load on a CPU-only machine, and
# predict() runs on the CPU anyway.
model.eval()
model.load_state_dict(torch.load('mymodel.pth', map_location='cpu'))
# Lookup table: row i is the one-hot vector for 0-based planet id i.
eye=torch.eye(8)
def predict(inp):
    """Return the three most likely next planets for a history of planet ids.

    Args:
        inp: sequence or numpy array of 1-based planet ids (1..8), e.g. the
            60-element array produced by ``My.verify`` in main.py. The caller's
            array is NOT modified (the old ``inp -= 1`` shifted it in place).

    Returns:
        List of three ``(probability, class_index)`` tuples in descending order
        of probability. ``probability`` is rounded to 3 decimals and
        ``class_index`` is the 0-based planet id.
    """
    # Shift to 0-based ids in a new array, then one-hot encode: (1, seq_len, 8).
    idx = torch.as_tensor(np.asarray(inp) - 1, dtype=torch.long)
    x = torch.unsqueeze(eye[idx], dim=0)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        probs = F.softmax(model(x), dim=-1)

    values, indices = torch.sort(probs, dim=-1, descending=True)
    values = torch.squeeze(values).numpy()
    indices = torch.squeeze(indices).numpy()
    return [(round(values[r], 3), indices[r]) for r in range(3)]

if __name__ == '__main__':
    # Smoke test: a history of 60 copies of planet id 1 (金星).
    print(predict(np.ones(60).astype('int32')))

下面是程序的主入口 main.py，直接运行该文件即可弹出界面并进行预测：

import numpy as np
from PyQt5.QtWidgets import *
from PyQt5 import QtWidgets
from PyQt5.QtGui import QPixmap,QImage
from PyQt5 import QtGui
from untitled import Ui_Form
from PyQt5.QtWidgets import QFileDialog
import cv2
import sys
from pred import predict
import warnings
warnings.filterwarnings("ignore")

class My(QtWidgets.QWidget,Ui_Form):
    """Main window: enter 60 planet names and predict the 61st with ``predict``.

    Wiring set up in ``__init__``:
      * pushButton   ("重新预测") -> ``start``: predict from the text box.
      * pushButton_2 ("继续预测") -> ``continues``: drop the oldest name, append
        the current top prediction and predict again.
      * pushButton_3 .. pushButton_10 -> ``q3`` .. ``q10``: append planet name
        ``self.name[0]`` .. ``self.name[7]`` to the text box.
    """

    def __init__(self):
        super(My,self).__init__()
        self.setupUi(self)
        self.use_palette()
        self.pushButton.clicked.connect(self.start)
        self.pushButton_2.clicked.connect(self.continues)

        self.pushButton_3.clicked.connect(self.q3)
        self.pushButton_4.clicked.connect(self.q4)
        self.pushButton_5.clicked.connect(self.q5)
        self.pushButton_6.clicked.connect(self.q6)
        self.pushButton_7.clicked.connect(self.q7)
        self.pushButton_8.clicked.connect(self.q8)
        self.pushButton_9.clicked.connect(self.q9)
        self.pushButton_10.clicked.connect(self.q10)

        # Pre-fill the input box with a 60-name example sequence.
        self.textEdit.insertPlainText('金星,木星,水星,火星,土星,天王星,海王星,开普勒,金星,木星,水星,火星,土星,天王星,海王星,开普勒,金星,木星,水星,火星,土星,天王星,海王星,开普勒,'
                                      '金星,木星,水星,火星,土星,天王星,海王星,开普勒,金星,木星,水星,火星,土星,天王星,海王星,开普勒,金星,木星,水星,火星,土星,天王星,海王星,开普勒,'
                                      '金星,木星,水星,火星,土星,天王星,海王星,开普勒,金星,木星,水星,火星')

        # Index i <-> planet id i + 1; must match the encoding used in dataset.py.
        self.name = ['金星', '木星', '水星', '火星', '土星', '天王星', '海王星', '开普勒']

    def use_palette(self):
        """Set the window title and use 123.jpg as the background."""
        self.setWindowTitle("时序预测")
        window_pale = QtGui.QPalette()
        window_pale.setBrush(self.backgroundRole(), QtGui.QBrush(QtGui.QPixmap("123.jpg")))
        self.setPalette(window_pale)

    def _split_names(self, text):
        """Split text-box contents into whitespace-stripped tokens.

        Newlines are removed, and full-width commas are treated like ASCII
        commas so that pasted Chinese text is accepted.
        """
        text = text.replace("\n", "").replace("\uff0c", ",")
        return [token.strip() for token in text.split(',')]

    def _show(self, outcom):
        """Display ``predict``'s top-3 ``(probability, class index)`` pairs."""
        prob_boxes = (self.lineEdit, self.lineEdit_2, self.lineEdit_3)
        name_boxes = (self.lineEdit_4, self.lineEdit_5, self.lineEdit_6)
        for (prob, idx), prob_box, name_box in zip(outcom, prob_boxes, name_boxes):
            prob_box.setText(str(prob))
            name_box.setText(self.name[idx])

    def start(self):
        """Predict the next planet from the current text-box contents."""
        datas = self.verify(self._split_names(self.textEdit.toPlainText()))
        # verify() returns the string '0' on failure. Check the type: the old
        # ``datas == '0'`` compared a numpy array with a string, which newer
        # NumPy evaluates element-wise, and ``if`` on that array raises.
        if isinstance(datas, str):
            return
        self._show(predict(datas))

    def continues(self):
        """Slide the window: drop the oldest name, append the top prediction, re-predict."""
        names = [token for token in self._split_names(self.textEdit.toPlainText()) if token][1:]
        names.append(self.lineEdit_4.text())
        self.textEdit.setText(",".join(names))
        datas = self.verify(names)
        if isinstance(datas, str):
            return
        self._show(predict(datas))

    def verify(self,data):
        """Convert planet names to ids 1..8.

        Args:
            data: iterable of name strings. Surrounding whitespace is ignored
                and unrecognised tokens are skipped.

        Returns:
            A numpy array of the 60 ids (1-based). If the number of recognised
            names is not 60, a warning dialog is shown and the string '0' is
            returned (this sentinel is kept for backward compatibility).
        """
        ids = {name: i + 1 for i, name in enumerate(self.name)}
        data2 = [ids[token.strip()] for token in data if token.strip() in ids]
        if len(data2) != 60:
            # Report how many *valid* names were found. The old code reported
            # the raw token count, which could read "60" while rejecting the input.
            msg_box = QMessageBox(QMessageBox.Warning, '警告', '数据有{}个,必须输入60个'.format(len(data2)))
            msg_box.exec_()
            return '0'
        return np.array(data2)

    def _append_name(self, index):
        """Append ``self.name[index]``, preceded by a comma if the box is non-empty."""
        sep = "," if self.textEdit.toPlainText() else ""
        self.textEdit.insertPlainText(sep + self.name[index])

    def q3(self):
        self._append_name(0)

    def q4(self):
        self._append_name(1)

    def q5(self):
        self._append_name(2)

    def q6(self):
        self._append_name(3)

    def q7(self):
        self._append_name(4)

    def q8(self):
        self._append_name(5)

    def q9(self):
        self._append_name(6)

    def q10(self):
        self._append_name(7)


if __name__ == '__main__':
    # Create the Qt application, show the main window and hand control to the
    # event loop; its return code becomes the process exit status.
    app = QtWidgets.QApplication(sys.argv)
    window = My()
    window.show()
    sys.exit(app.exec_())

下面是 untitled.py 的代码（由 pyuic5 根据 untitled.ui 自动生成）：

# -*- coding: utf-8 -*-

# Form implementation generated from reading ui file 'untitled.ui'
#
# Created by: PyQt5 UI code generator 5.15.4
#
# WARNING: Any manual changes made to this file will be lost when pyuic5 is
# run again.  Do not edit this file unless you know what you are doing.


from PyQt5 import QtCore, QtGui, QtWidgets


class Ui_Form(object):
    """pyuic5-generated layout for the prediction window (see header: edits are lost on regeneration)."""
    def setupUi(self, Form):
        """Create and position all widgets on ``Form``, then set their texts via ``retranslateUi``."""
        Form.setObjectName("Form")
        Form.resize(677, 564)
        # Multi-line input box; main.py fills it with comma-separated planet names.
        self.textEdit = QtWidgets.QTextEdit(Form)
        self.textEdit.setGeometry(QtCore.QRect(40, 80, 351, 311))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(14)
        font.setBold(True)
        font.setWeight(75)
        self.textEdit.setFont(font)
        self.textEdit.setStyleSheet("background-color: qconicalgradient(cx:0, cy:0, angle:135, stop:0 rgba(255, 255, 0, 69), stop:0.375 rgba(255, 255, 0, 69), stop:0.423533 rgba(251, 255, 0, 145), stop:0.45 rgba(247, 255, 0, 208), stop:0.477581 rgba(255, 244, 71, 130), stop:0.518717 rgba(255, 218, 71, 130), stop:0.55 rgba(255, 255, 0, 255), stop:0.57754 rgba(255, 203, 0, 130), stop:0.625 rgba(255, 255, 0, 69), stop:1 rgba(255, 255, 0, 69));")
        self.textEdit.setObjectName("textEdit")
        # Caption above the input box (text set in retranslateUi).
        self.label = QtWidgets.QLabel(Form)
        self.label.setGeometry(QtCore.QRect(40, 30, 161, 31))
        font = QtGui.QFont()
        font.setFamily("Agency FB")
        font.setPointSize(18)
        font.setBold(True)
        font.setWeight(75)
        self.label.setFont(font)
        self.label.setObjectName("label")
        # Header of the predicted-name column.
        self.label_2 = QtWidgets.QLabel(Form)
        self.label_2.setGeometry(QtCore.QRect(520, 150, 111, 31))
        font = QtGui.QFont()
        font.setFamily("Agency FB")
        font.setPointSize(18)
        font.setBold(True)
        font.setWeight(75)
        self.label_2.setFont(font)
        self.label_2.setObjectName("label_2")
        # Action buttons: predict from scratch / continue predicting.
        self.pushButton = QtWidgets.QPushButton(Form)
        self.pushButton.setGeometry(QtCore.QRect(170, 480, 91, 41))
        self.pushButton.setObjectName("pushButton")
        self.pushButton_2 = QtWidgets.QPushButton(Form)
        self.pushButton_2.setGeometry(QtCore.QRect(400, 480, 91, 41))
        self.pushButton_2.setObjectName("pushButton_2")
        # Probability column: lineEdit, lineEdit_2, lineEdit_3 (top-3, written by main.py).
        self.lineEdit = QtWidgets.QLineEdit(Form)
        self.lineEdit.setGeometry(QtCore.QRect(420, 200, 101, 41))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(20)
        font.setBold(True)
        font.setWeight(75)
        self.lineEdit.setFont(font)
        self.lineEdit.setObjectName("lineEdit")
        self.lineEdit_2 = QtWidgets.QLineEdit(Form)
        self.lineEdit_2.setGeometry(QtCore.QRect(420, 260, 101, 41))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(20)
        font.setBold(True)
        font.setWeight(75)
        self.lineEdit_2.setFont(font)
        self.lineEdit_2.setObjectName("lineEdit_2")
        self.lineEdit_3 = QtWidgets.QLineEdit(Form)
        self.lineEdit_3.setGeometry(QtCore.QRect(420, 320, 101, 41))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(20)
        font.setBold(True)
        font.setWeight(75)
        self.lineEdit_3.setFont(font)
        self.lineEdit_3.setObjectName("lineEdit_3")
        # Header of the probability column.
        self.label_3 = QtWidgets.QLabel(Form)
        self.label_3.setGeometry(QtCore.QRect(420, 150, 61, 31))
        font = QtGui.QFont()
        font.setFamily("Agency FB")
        font.setPointSize(18)
        font.setBold(True)
        font.setWeight(75)
        self.label_3.setFont(font)
        self.label_3.setObjectName("label_3")
        # Predicted-name column: lineEdit_4, lineEdit_5, lineEdit_6 (top-3, written by main.py).
        self.lineEdit_4 = QtWidgets.QLineEdit(Form)
        self.lineEdit_4.setGeometry(QtCore.QRect(540, 200, 101, 41))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(20)
        font.setBold(True)
        font.setWeight(75)
        self.lineEdit_4.setFont(font)
        self.lineEdit_4.setObjectName("lineEdit_4")
        self.lineEdit_5 = QtWidgets.QLineEdit(Form)
        self.lineEdit_5.setGeometry(QtCore.QRect(540, 260, 101, 41))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(20)
        font.setBold(True)
        font.setWeight(75)
        self.lineEdit_5.setFont(font)
        self.lineEdit_5.setObjectName("lineEdit_5")
        self.lineEdit_6 = QtWidgets.QLineEdit(Form)
        self.lineEdit_6.setGeometry(QtCore.QRect(540, 320, 101, 41))
        font = QtGui.QFont()
        font.setFamily("Arial")
        font.setPointSize(20)
        font.setBold(True)
        font.setWeight(75)
        self.lineEdit_6.setFont(font)
        self.lineEdit_6.setObjectName("lineEdit_6")
        # One button per planet name (pushButton_3 .. pushButton_10); labels set in retranslateUi.
        self.pushButton_3 = QtWidgets.QPushButton(Form)
        self.pushButton_3.setGeometry(QtCore.QRect(40, 430, 61, 31))
        self.pushButton_3.setObjectName("pushButton_3")
        self.pushButton_4 = QtWidgets.QPushButton(Form)
        self.pushButton_4.setGeometry(QtCore.QRect(110, 430, 61, 31))
        self.pushButton_4.setObjectName("pushButton_4")
        self.pushButton_5 = QtWidgets.QPushButton(Form)
        self.pushButton_5.setGeometry(QtCore.QRect(180, 430, 61, 31))
        self.pushButton_5.setObjectName("pushButton_5")
        self.pushButton_6 = QtWidgets.QPushButton(Form)
        self.pushButton_6.setGeometry(QtCore.QRect(250, 430, 61, 31))
        self.pushButton_6.setObjectName("pushButton_6")
        self.pushButton_7 = QtWidgets.QPushButton(Form)
        self.pushButton_7.setGeometry(QtCore.QRect(320, 430, 61, 31))
        self.pushButton_7.setObjectName("pushButton_7")
        self.pushButton_8 = QtWidgets.QPushButton(Form)
        self.pushButton_8.setGeometry(QtCore.QRect(390, 430, 61, 31))
        self.pushButton_8.setObjectName("pushButton_8")
        self.pushButton_9 = QtWidgets.QPushButton(Form)
        self.pushButton_9.setGeometry(QtCore.QRect(460, 430, 61, 31))
        self.pushButton_9.setObjectName("pushButton_9")
        self.pushButton_10 = QtWidgets.QPushButton(Form)
        self.pushButton_10.setGeometry(QtCore.QRect(530, 430, 61, 31))
        self.pushButton_10.setObjectName("pushButton_10")

        self.retranslateUi(Form)
        QtCore.QMetaObject.connectSlotsByName(Form)

    def retranslateUi(self, Form):
        """Set all user-visible texts (window title, captions, initial values, button labels)."""
        _translate = QtCore.QCoreApplication.translate
        Form.setWindowTitle(_translate("Form", "Form"))
        self.label.setText(_translate("Form", "输入数据:"))
        self.label_2.setText(_translate("Form", "预测值"))
        self.pushButton.setText(_translate("Form", "重新预测"))
        self.pushButton_2.setText(_translate("Form", "继续预测"))
        self.lineEdit.setText(_translate("Form", "0"))
        self.lineEdit_2.setText(_translate("Form", "0"))
        self.lineEdit_3.setText(_translate("Form", "0"))
        self.label_3.setText(_translate("Form", "概率"))
        self.lineEdit_4.setText(_translate("Form", "0"))
        self.lineEdit_5.setText(_translate("Form", "0"))
        self.lineEdit_6.setText(_translate("Form", "0"))
        self.pushButton_3.setText(_translate("Form", "金星"))
        self.pushButton_4.setText(_translate("Form", "木星"))
        self.pushButton_5.setText(_translate("Form", "水星"))
        self.pushButton_6.setText(_translate("Form", "火星"))
        self.pushButton_7.setText(_translate("Form", "土星"))
        self.pushButton_8.setText(_translate("Form", "天王星"))
        self.pushButton_9.setText(_translate("Form", "海王星"))
        self.pushButton_10.setText(_translate("Form", "开普勒"))

欢迎点赞加关注，有问题请在下方评论！

你可能感兴趣的：transformer、pytorch、深度学习