Pytorch实现数据加载—模型搭建训练—模型保存—模型测试评估完整流程

1. Pytorch框架下对模型的数据加载,模型搭建,保存以及测试评估 

# -*- coding: utf-8 -*-
"""
Created on Mon Apr 24 13:37:50 2023

@author: 茶墨先生
"""
import torch
import numpy as np
from torch.utils.data import Dataset,DataLoader
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam

'''加载训练数据'''
y_train_feature=np.load('C:\结果\y_train_speech_feature.npy')
y_test_feature=np.load('C:\结果\y_test_speech_feature.npy')
y_train=np.load('C:\结果\y_train.npy')
y_test=np.load('C:\结果\y_test.npy')

y_test_feature=(y_test_feature.reshape(-1,7,129)).astype(np.float32)
y_train_feature=(y_train_feature.reshape(-1,7,129)).astype(np.float32)
y_test=(y_test.reshape(-1,129)).astype(np.float32)
y_train=(y_train.reshape(-1,129)).astype(np.float32)

Batch_Bize=128

"数据集类"
class MyDataset(Dataset):
    def __init__(self,input_data):
         self.data=input_data
        
    def __getitem__(self,index):
        
        return self.data[index]#获取当前索引的一条数据
    
    def __len__(self):
        
        return len(self.data)#返回数据总数量
    
"模型构建"
class Mymold(nn.Module):
    
    def __init__(self):
        super(Mymold,self).__init__()#继承父类
        self.gru=nn.GRU(input_size=129,hidden_size=128,num_layers=2,batch_first=True)
        
        self.fc1=nn.Linear(7*128,512)
        self.fc2=nn.Linear(512,512)
        self.fc3=nn.Linear(512,512)
        self.fc4=nn.Linear(512,129)
        '''前向传播'''
    def forward(self,input):
        x=input.view(-1,7,129)#改变数据形状
        x,hn=self.gru(x)
    
        x=x.reshape(-1,7*128)#改变数据形状
        x=self.fc1(x)
        x=F.relu(x)
        x=self.fc2(x)
        x=F.relu(x)
        x=self.fc3(x)
        x=F.relu(x)
        x=self.fc4(x)
        return x.view(-1,129)
    
"实例化模型,优化器,损失函数"
mymold=Mymold()
mymold.train(True)
optimer=Adam(mymold.parameters(),0.001)
#loss_mymold=nn.MSELoss()
loss_mymold=nn.L1Loss()


"定义训练函数"
def train():
    "打乱数据"
    index_train=[i for i in range(len(y_train))]
    np.random.shuffle(index_train)
    y_train1=y_train[index_train,:]
    y_train_feature1=y_train_feature[index_train,:,:]

    "加载分解数据"
    my_train_feature=MyDataset(y_train_feature1)
    my_train_feature=DataLoader(dataset=my_train_feature,batch_size=Batch_Bize,shuffle=False)

    my_train_target=MyDataset(y_train1)
    my_train_target=DataLoader(dataset=my_train_target,batch_size=Batch_Bize,shuffle=False)
 
    "训练模型"
    for target,input_feature in zip(my_train_target,my_train_feature):
        optimer.zero_grad()
        output=mymold(input_feature)
        loss=loss_mymold(output,target)
        loss.backward()#反向传播
        optimer.step()#参数更新
        
#        print(loss.item())  
"测试模型" 
def test():
    "加载分解数据"
    my_test_feature=MyDataset(y_test_feature)
    my_test_feature=DataLoader(dataset=my_test_feature,batch_size=Batch_Bize,shuffle=False)

    my_test_target=MyDataset(y_test)
    my_test_target=DataLoader(dataset=my_test_target,batch_size=Batch_Bize,shuffle=False)
    
    Loss=[]
    for target,input_feature in zip(my_test_target,my_test_feature):
        output=mymold(input_feature)
        loss=loss_mymold(output,target)
        Loss.append(loss.item())
    print("TestLoss:",np.mean(Loss))
    return np.mean(Loss)


if __name__ =='__main__':
    
    Loss_Start=200.#设置原始测试损失值
    Patience=5#设置模型不更新的最大次数
    patience=1
    for i in  range(1000):
        print("epoch:",i)
        train()
        loss_test=test()#获得模型的测试损失
        if loss_test<=Loss_Start:
            "保存当前较好的模型参数"
            print("更新模型")
            torch.save(mymold.state_dict(),r"E:\Pytorch\模型\mold.pt")#保存模型参数
            torch.save(optimer.state_dict(),r"E:\Pytorch\模型\optimer.pt")#保存优化器参数
            Loss_Start=loss_test
            patience=1
        else:
            print("不更新模型")
            patience=patience+1
            if patience>Patience:
                break
        

"加载模型"
mymold.load_state_dict("E:\Pytorch\模型\mold.pt")
optimer.load_state_dict("E:\Pytorch\模型\optimer.pt")







2. 运行结果图展示 

  1. Pytorch实现数据加载—模型搭建训练—模型保存—模型测试评估完整流程_第1张图片

注:私信留言获取数据集和特征提取源码 

你可能感兴趣的:(深度学习单通道语音增强,Pytorch深度学习,pytorch,深度学习,python,人工智能,神经网络)