【PyTorch模型训练实用教程】02 让PyTorch读取数据集

目的:上一小节将cifar10的图片存储为png格式后,又划分了训练集、验证集和测试集。这一小节的目的是为了让PyTorch能读取我们的数据集。

首先读取图片的路径、标签并将其保存到txt文件中

1_3_generate_txt

# coding:utf-8
import os
'''
    为数据集生成对应的txt文件
'''

train_txt_path = os.path.join("..", "..", "Data", "train.txt")
train_dir = os.path.join("..", "..", "Data", "train")

valid_txt_path = os.path.join("..", "..", "Data", "valid.txt")
valid_dir = os.path.join("..", "..", "Data", "valid")


def gen_txt(txt_path, img_dir):
    f = open(txt_path, 'w')
    
    for root, s_dirs, _ in os.walk(img_dir, topdown=True):  # 获取 train文件下各文件夹名称
        for sub_dir in s_dirs:
            i_dir = os.path.join(root, sub_dir)             # 获取各类的文件夹 绝对路径
            img_list = os.listdir(i_dir)                    # 获取类别文件夹下所有png图片的路径
            for i in range(len(img_list)):
                if not img_list[i].endswith('png'):         # 若不是png文件,跳过
                    continue
                label = img_list[i].split('_')[0]
                img_path = os.path.join(i_dir, img_list[i])
                line = img_path + ' ' + label + '\n'
                f.write(line)
    f.close()


if __name__ == '__main__':
    gen_txt(train_txt_path, train_dir)
    gen_txt(valid_txt_path, valid_dir)

①.os.listdir() 

输入:目录路径

输出:该路径下的文件和文件夹列表

例如代码中当i_dir为\Data\train\0时,img_list返回的就是类别为0的所有图片名称

其次是构建自己的Dataset子类

1_3_mydataset

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, txt_path, transform=None, target_transform=None):
        fh = open(txt_path, 'r')
        imgs = []
        for line in fh:
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0], int(words[1])))

        self.imgs = imgs        # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB')     # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.imgs)

代码内容概述:

MyDataset类中包含三个函数。初始化函数是将txt中的图片信息存储到列表self.imgs中,该列表中的每个元素都包含一张图片的地址和该图片对应标签。__getitem__()函数是输入一个索引index,返回该索引对应的图片和标签。__len__()函数返回图片总数。

不懂的地方:

DataLoader类和MyDataset类的作用

总结(将图片输入模型之前都要进行什么处理):

【PyTorch模型训练实用教程】02 让PyTorch读取数据集_第1张图片

你可能感兴趣的:(PyTorch模型训练实用教程,pytorch,深度学习,python)