pytorch实现自己制作训练集和测试集

pytorch可用于图像识别,但我们现在绝大部分用的是MINIST和cifar10图片,想要用自己的训练和测试图像路径,需要制作读取训练集和测试集的代码。本文讲述pytorch实现读取训练集和测试集通用代码。

首先讲一下读取图片路径的框架:
torch.utils.data.Dataset是一个pytorch用来表示数据集的抽象类,我们用这个类来处理自己的数据集时必须继承Dataset,然后重写下面的函数:
len:使得len(dataset)返回数据集的大小
getitem:使得支持dataset[i]能够返回第i个数据样本的下标操作。

接下来给出具体的代码片段,需要调整的是img_id = int(img_path[-12:-9]),这段是读文件名,需要根据自己的文件名来设定img_path里的数值。该文件命名为Path.py

import glob
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

class path(Dataset):

    def __init__(self, root_path):
        self.mDataX = []
        self.mDataY = []

        for img_path in glob.glob(root_path + r'\*'):
            img = Image.open(img_path)
            img = img.convert('L')  # 转换成灰度图
            # new_size = np.array(img.size) / 4
            # new_size = new_size.astype(int)
            # img = img.resize(new_size, Image.BILINEAR)  # 从(宽,高)(640, 480)缩小为(160, 120)
            img_data = np.array(img, dtype=float)
            img_data = img_data.reshape(-1)
            self.mDataX.append(img_data)
            img_id = int(img_path[-12:-9])  ##文件名的后多少位,以后用这里需要根据图像名称更改
            self.mDataY.append(img_id)

        self.mDataX = torch.tensor(self.mDataX)
        self.mDataY = torch.tensor(self.mDataY)

    def __getitem__(self, data_index):
        input_tensor = torch.tensor(self.mDataX[data_index])
        output_tensor = torch.tensor(self.mDataY[data_index])
        return input_tensor, output_tensor

    def __len__(self):
        return len(self.mDataX)

接下来是调用上面写好的.py文件

import torch
from Path import *

train_set = Path(r'C:\datasets\人脸图像\训练样例')
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=False)  # , num_workers=8
print('OK! ', len(train_set), len(train_loader))

test_set = Path(r'C:\datasets\人脸图像\测试样例')
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False)  # , num_workers=8
print('OK! ', len(test_set), len(test_loader))

接下来可以使用你自己的训练集和测试集图像了!!!

你可能感兴趣的:(神经网络,pytorch,机器学习,深度学习)