PyTorch入门笔记2:PyTorch加载数据实战

PyTorch加载数据讲解

一、 Dataset 和 Dataloader

对于数据,我们对于已有的数据,需要用到Dataset(数据集)和DataLoader(数据加载)器。

Pytorch 读取数据主要涉及两个类:DatasetDataloader
数据可类比为“垃圾”,不同数据是不同种类的垃圾,这里蓝色是可回收垃圾。
Dataset能够把垃圾中的可回收垃圾即蓝色块给挑选出来,并对其进行编号,供后续网络的使用。

PyTorch入门笔记2:PyTorch加载数据实战_第1张图片

而数据进入网络不会是一个个送进去,在送进去之前会进行打包,比如以一次多个的形式把数据输入进网络。
总结:

Dataset提供了一种方式去获取每个数据及其label并告诉我们总共有多少的数据。
Dataloader为数据进行打包,给要训练的网络提供不同形式的数据。

二、数据集初识

数据集 蚂蚁蜜蜂分类 下载链接:https://download.pytorch.org/tutorial/hymenoptera_data.zip
解压打开查看,分为训练数据集和验证数据集。
两个文件夹都分别有分类好的蚂蚁和蜜蜂的图片,这是一个用于对蚂蚁和蜜蜂进行二分类的数据集。

三、Dataset类初识

打开jupyter,新建一个名为 read_dataset的notebook。输入下图所示代码:
可以看到
PyTorch入门笔记2:PyTorch加载数据实战_第2张图片

Dataset的使用说明表示任何数据集应该继承Dataset,并改写成员函数:__getitem____len__(可选)。

这里将数据集放到工程目录下,这样就可以用相对路径进行访问了:
代码:

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir,self.label_dir)
        # get relative address of ants pictures
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):
        """
        :param idx: img_name
        :return: object of data,label
        """
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label
    
    def __len__(self):
        return len(self.img_path)

root_dir = "hymenoptera_data/train"
ants_label_dir ="ants"
bees_label_dir ="bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

train_dataset = ants_dataset + bees_dataset
img, label = train_dataset[0]
print(label)
img.show()

这里就是对数据集进行简单读取,可以通过索引来对指定的数据进行图片信息和 label 读取,输出如下图所示:PyTorch入门笔记2:PyTorch加载数据实战_第3张图片

这里就是对数据集进行简单读取,可以通过索引来对指定的数据进行图片信息和 label 读取,输出如下图所示:

PyTorch入门笔记2:PyTorch加载数据实战_第4张图片

你可能感兴趣的:(Pytorch学习,pytorch,python,计算机视觉,人工智能,图像处理)