PyTorch入门学习(一):加载数据:Dataset类代码实战

目录

一、数据加载与编号

二、数据集的组织方式

三、使用PyTorch的Dataset类


 视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】                                    


一、数据加载与编号

在机器学习和深度学习任务中,数据加载是一个非常重要的步骤。在许多情况下,我们需要将数据加载到模型中,同时还需要获取与数据相关的标签。Dataset类在PyTorch中提供了两个主要功能:

  • 如何获取每一个数据以及其标签。
  • 告诉我们总共有多少个数据点。

这两个功能是数据加载的关键部分,而Dataset类帮助我们轻松实现这些功能。

二、数据集的组织方式

数据集的组织方式有多种形式,但在本文中,我们将关注两种常见的方式:

  • 文件夹的名称即为数据的标签。
  • 文件名和标签分别位于两个不同的文件夹中,标签可以使用文本文件(如txt)进行存储。

三、使用PyTorch的Dataset类

  • 导入必要的的库
    import os
    from PIL import Image
    from torch.utils.data import Dataset
  • __init__(self, root_dir, label_dir)方法用于初始化数据集对象。它接受两个参数:root_dir是数据集的根目录,label_dir是与数据标签相关的子目录。
    class MyData(Dataset):
        def __init__(self, root_dir, label_dir):
            self.root_dir = root_dir
            self.label_dir = label_dir
            self.path = os.path.join(self.root_dir, self.label_dir)
            self.img_path = os.listdir(self.path)
  • __getitem__(self, idx)方法用于获取数据和其标签。它接受一个索引idx,通过该索引获取特定数据点的图像和标签。
        def __getitem__(self, idx):
            img_name = self.img_path[idx]
            img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
            img = Image.open(img_item_path)
            label = self.label_dir
            return img, label
  • __len__(self)方法返回数据集的长度,也就是数据点的总数。
        def __len__(self):
            return len(self.img_path)
  • 完整代码如下所示:
    # 导入必要的库
    import os
    from PIL import Image
    from torch.utils.data import Dataset
    
    # 创建一个自定义的数据集类 MyData,继承自 PyTorch 的 Dataset 类
    class MyData(Dataset):
        # 初始化方法,用于设置数据集的基本信息
        def __init__(self, root_dir, label_dir):
            # 存储数据集根目录和标签子目录
            self.root_dir = root_dir
            self.label_dir = label_dir
            # 拼接数据路径,组成完整的路径
            self.path = os.path.join(self.root_dir, self.label_dir)
            # 获取指定标签子目录下的所有文件名
            self.img_path = os.listdir(self.path)
    
        # 获取数据点的方法,根据索引(idx)返回图像和标签
        def __getitem__(self, idx):
            # 获取特定索引的图像文件名
            img_name = self.img_path[idx]
            # 拼接图像文件的完整路径
            img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
            # 使用 PIL 库打开图像文件
            img = Image.open(img_item_path)
            # 标签是当前数据点所在的标签子目录名称
            label = self.label_dir
            # 返回图像和标签
            return img, label
    
        # 获取数据集长度的方法,返回数据点的总数
        def __len__(self):
            return len(self.img_path)
    
    # 在程序的入口点运行以下代码
    if __name__ == '__main__':
        # 定义数据集的根目录
        root_dir = "dataset/train"
        # 分别指定两个不同的标签子目录
        ants_label_dir = "ants_image"
        bees_label_dir = "bees_image"
        # 创建两个数据集对象,分别用于加载不同标签的数据
        ants_dataset = MyData(root_dir, ants_label_dir)
        bees_dataset = MyData(root_dir, bees_label_dir)
        # 合并两个数据集以创建一个用于训练的数据集
        train_dataset = ants_dataset + bees_dataset
    

你可能感兴趣的:(PyTorch,pytorch,学习,深度学习)