03_PyTorch 模型训练[Dataset 类读取数据集]

PyTorch 读取图片,主要是通过 Dataset 类 ,所以先简单了解一下 Dataset 类。 Dataset
类作为所有的 datasets 的基类存在,所有的 datasets 都需要继承它,类似于 C++ 中的虚基
类。
03_PyTorch 模型训练[Dataset 类读取数据集]_第1张图片

 

这里重点看 getitem 函数, getitem 接收一个 index ,然后返回图片数据和标签,这个
index 通常指的是一个 list index ,这个 list 的每个元素就包含了图片数据的路径和标签信
息。
然而,如何制作这个 list 呢,通常的方法是将图片的路径和标签信息存储在一个 txt
中,然后从该 txt 中读取。
那么读取自己数据的基本流程就是:
1. 制作存储了图片的路径和标签信息的 txt
2. 将这些信息转化为 list ,该 list 每一个元素对应一个样本
3. 通过 getitem 函数,读取数据和标签,并返回数据和标签
因此,要让 PyTorch 能读取自己的数据集,只需要两步:
1. 制作图片数据的索引
2. 构建 Dataset 子类
1.生成记事本代码
import os
base_dir = "E:/pytorch_learning" #修改为当前Data 目录所在的绝对路径 
'''
    为数据集生成对应的txt文件
'''
base_dir = "E:/pytorch_learning" #修改为当前Data 目录所在的绝对路径 
train_txt_path = os.path.join(base_dir, "Data", "train.txt")
train_dir = os.path.join(base_dir, "Data", "train")
valid_txt_path = os.path.join(base_dir, "Data", "valid.txt")
valid_dir = os.path.join(base_dir, "Data", "valid")
print(train_txt_path)
print(train_dir)
print(valid_txt_path)
print(valid_dir)
def gen_txt(txt_path, img_dir):
    f = open(txt_path, 'w')
    
    for root, s_dirs, _ in os.walk(img_dir, topdown=True):  # 获取 train文件下各文件夹名称
        for sub_dir in s_dirs:
            i_dir = os.path.join(root, sub_dir)             # 获取各类的文件夹 绝对路径
            img_list = os.listdir(i_dir)                    # 获取类别文件夹下所有png图片的路径
            for i in range(len(img_list)):
                if not img_list[i].endswith('png'):         # 若不是png文件,跳过
                    continue
                label = img_list[i].split('_')[0]
                img_path = os.path.join(i_dir, img_list[i])
                line = img_path + ' ' + label + '\n'
                f.write(line)
    f.close()
gen_txt(train_txt_path, train_dir)
gen_txt(valid_txt_path, valid_dir)

2.效果

03_PyTorch 模型训练[Dataset 类读取数据集]_第2张图片

 03_PyTorch 模型训练[Dataset 类读取数据集]_第3张图片

 3.Dataset类代码

class MyDataset(Dataset):
    def __init__(self, txt_path, transform=None, target_transform=None):
        fh = open(txt_path, 'r') 
        imgs = []
        for line in fh:
            line = line.rstrip() #rstrip函数返回字符串副本,该副本是从字符串最右边删除了参数指定字符后的字符串,不带参数进去则是去除最右边的空格
            words = line.split() #默认以空格为分隔符
            imgs.append((words[0], int(words[1])))

        self.imgs = imgs        # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
        # transform 是一个 Compose 类型,里边有一个 list,list 中就会定义了各种对图像进行处理的操作,
        #可以设置减均值,除标准差,随机裁剪,旋转,翻转,仿射变换等操作
        #在这里我们可以知道,一张图片读取进来之后,会经过数据处理(数据增强),
        #最终变成输入模型的数据。这里就有一点需要注意,PyTorch 的数据增强是将原始图片进行了处理
        #并不会生成新的一份图片,而是“覆盖”原图
        self.target_transform = target_transform
        self.transform = transform 

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        #对图片进行读取
        img = Image.open(fn).convert('RGB')     # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.imgs)

4.dataload

Mydataset 构建好,剩下的操作就交给 DataLoder ,在 DataLoder 中,会触发
Mydataset 中的 getiterm 函数读取一张图片的数据和标签,并拼接成一个 batch 返回,作为
模型真正的输入。
03_PyTorch 模型训练[Dataset 类读取数据集]_第4张图片
03_PyTorch 模型训练[Dataset 类读取数据集]_第5张图片

 

你可能感兴趣的:(pytorch,深度学习,人工智能)