Reading zs data into a Dataset and training a one-dimensional convolution model

trainer.py

import torch
from torch import nn
from nets import ConvNet, LstmNet
from dataset import tempDataset
import os
from torch.utils.data import Dataset, DataLoader

class Trainer:
    def __init__(self, save_path, dataset_path, isLstm=False):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.save_path = save_path
        self.isLstm = isLstm
        self.dataset = tempDataset(dataset_path)
        if isLstm:
            self.net = LstmNet().to(self.device)
        else:
            self.net = ConvNet().to(self.device)
        self.loss_fn = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.0001)
        # Make sure the checkpoint directory exists (save_path is a filename prefix).
        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)

    def train(self):
        best_loss = float("inf")
        # Default batch_size=1, so output.item()/label.item() below are valid.
        dataLoader = DataLoader(dataset=self.dataset)
        for epoch in range(1, 21):
            for i, (x, y) in enumerate(dataLoader):
                if self.isLstm:
                    # The LSTM branch expects 9-step sequences (a different dataset).
                    inputs = x.reshape(-1, 9, 1).float().to(self.device)
                else:
                    inputs = x.reshape(-1, 1, 2).float().to(self.device)
                label = y.reshape(-1, 1).float().to(self.device)
                output = self.net(inputs)
                loss = self.loss_fn(output, label)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                if i % 10 == 0:
                    print("epoch:{},i:{},loss:{},output:{},label:{},output-label:{:.6f}".format(
                        epoch, i, loss.item(), output.item(), label.item(),
                        output.item() - label.item()))
                # Save a checkpoint whenever the loss improves on the best seen so far.
                if loss.item() < best_loss:
                    best_loss = loss.item()
                    torch.save(self.net.state_dict(), self.save_path + '{}.pth'.format(epoch))
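
After training, a saved state_dict can be loaded back for inference. A minimal sketch, assuming train() produced a checkpoint such as ../zsr/models/net1.pth (the epoch number in the filename will vary):

import torch
from nets import ConvNet

net = ConvNet()
# Path is illustrative; point it at whichever epoch checkpoint train() saved.
net.load_state_dict(torch.load("../zsr/models/net1.pth", map_location="cpu"))
net.eval()

# One sample: the two feature columns, shaped (batch, channels=1, length=2) as in training.
x = torch.tensor([[0.4236480479560859, 0.0970777454099595]], dtype=torch.float32)
with torch.no_grad():
    pred = net(x.reshape(-1, 1, 2))
print(pred.item())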

nets.py

from torch import nn
import torch


class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_layer = nn.Sequential(
            ConvolutionalLayer(1, 8, 2, 1, 1),  # n,8,3
            ConvolutionalLayer(8, 16, 2),  # n,16,2
            ConvolutionalLayer(16, 32, 2)  # n,32,1
        )
        self.linear_layer = nn.Sequential(
            nn.Linear(32, 1),
            # nn.Sigmoid()
        )

    def forward(self, data):
        conv = self.conv_layer(data)
        conv = conv.reshape(-1, 32)
        output = self.linear_layer(conv)
        return output


class LstmNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm_layer = nn.LSTM(1, 1024, 2, batch_first=True)
        self.linear_layer = nn.Sequential(
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, data):
        lstm_out, (h, c) = self.lstm_layer(data)
        lstm_out = lstm_out[:, -1]  # keep only the last time step
        lstm_out = lstm_out.reshape(-1, 1024)
        output = self.linear_layer(lstm_out)
        return output


class ConvolutionalLayer(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size, stride=1, padding=0, groups=1, bias=True):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Conv1d(input_channels, output_channels, kernel_size, stride, padding, groups=groups, bias=bias),
            # nn.BatchNorm1d(output_channels),
            nn.PReLU()
        )

    def forward(self, data):
        return self.layer(data)


if __name__ == '__main__':
    # net = LstmNet()
    net = ConvNet()
    data = torch.tensor([x / 1 for x in range(2)]).reshape(1, 1, 2)
    # data = torch.tensor([x / 9 for x in range(9)]).reshape(1, 9, 1)
    output = net(data)
    print(output)
    print(output.shape)
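
For reference, the Conv1d output length follows L_out = floor((L_in + 2*padding - kernel_size) / stride) + 1. Tracing ConvNet's stack on an input of length 2: the first layer gives (2 + 2*1 - 2)/1 + 1 = 3, the second (3 - 2)/1 + 1 = 2, and the third (2 - 2)/1 + 1 = 1, which is why forward() can flatten the result to shape (-1, 32).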

main.py

from trainer import Trainer

if __name__ == '__main__':
    # save_path is a filename prefix; train() appends '<epoch>.pth' to it.
    trainer = Trainer("../zsr/models/net", "../zsr/sample.txt", isLstm=False)
    # trainer = Trainer("../cnn_1d/models/lstm_net.pth", "../cnn_1d/data/data2.xlsx", isLstm=True)
    trainer.train()

dataset.py

from torch.utils.data import Dataset, DataLoader
import numpy as np

def read_txt(path):
    """Read all lines of a text file."""
    with open(path, 'r') as f:
        return f.readlines()
    
class tempDataset(Dataset):
    def __init__(self, path):
        self.X = []
        self.Y = []
        self.path = path
        flines = read_txt(self.path)
        flines = [line.strip() for line in flines]
        for line in flines[1:]:  # skip the first line (assumed to be a header)
            values = [float(v) for v in line.split(' ')]
            self.Y.append(values[0])                 # column 0: regression target
            self.X.append([values[1], values[-1]])   # columns 1 and 2: features
        self.X = np.array(self.X, dtype=np.float32)
        self.Y = np.array(self.Y, dtype=np.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, item):
        # Return a single un-batched sample; the DataLoader handles batching.
        return self.X[item], self.Y[item]

if __name__ == '__main__':
    path = "./distance_sample.txt"
    dataset = tempDataset(path)
    dataLoader = DataLoader(dataset=dataset)
    for i, (x, y) in enumerate(dataLoader):
        print(i, x.shape, y.shape)
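
Because __getitem__ returns one un-batched sample, batching is left entirely to the DataLoader. A quick sketch with a larger batch (illustrative; it assumes sample.txt sits next to the script):

from torch.utils.data import DataLoader
from dataset import tempDataset

dataset = tempDataset("./sample.txt")
loader = DataLoader(dataset, batch_size=4, shuffle=True)
for x, y in loader:
    print(x.shape, y.shape)  # torch.Size([4, 2]) torch.Size([4])
    break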

sample.txt
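
Each row holds three space-separated floats: the first column is the regression target y, and the remaining two columns are the features that tempDataset packs into each sample. (Note that tempDataset skips the first row, treating it as a header.)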

0.25 0.4236480479560859 0.0970777454099595
0.26 -1.7146637075537476 -0.46513226330560054
0.27 0.4188321031181166 1.86239657386961
0.28 -0.6878831117260947 1.2465211119719048
0.29 2.6974890810128023 0.835188151761743
0.3 1.8910601248808103 1.689956673665461
0.31 0.2629244309998292 1.083057050465252
0.32 0.5900272940038855 3.15234530797878
0.33 -0.26585576830236385 1.9378389294664315
0.34 -0.10468282807426771 2.3975277130209163
0.35 1.8152280779751142 2.1274933671930545
0.36 -0.9193467358583763 0.7240830004041563
0.37 -0.32316160541893046 1.3726967862830741
0.38 1.2830681983204184 -0.23556575804191737
0.39 -0.17460099632805381 0.5493902892851491
0.4 2.227213896418279 1.714204568215747
0.41 -1.8816158329308459 0.9542160243561777
0.42 -0.4874561918326939 1.0337512588290139
0.43 0.9296375301572501 2.408488926857565
0.44 -0.3773632465927879 1.4726729480049072
0.45 1.8579102338581668 1.1749968315283246
0.46 0.7609286927591787 0.9709886393215672
0.47 -0.6722156430805126 1.8506015511546934
0.48 1.5139102898496717 2.5177739671059816
0.49 0.424780329856696 1.5924040712267007
0.5 1.3480856240476284 1.5148655701587392
0.51 -0.03171469835768792 1.1151866427331552
0.52 0.24486653318837992 -0.6966293490183777
0.53 0.13087323931782024 -0.24814261388501202
0.54 0.7480863491741208 2.6558186256275347
0.55 1.9854865916945192 0.8986700094485898
