Dataset类实践

Dataset类实践
蚂蚁蜜蜂分类数据集和下载链接https://download.pytorch.org/tutorial/hymenoptera_data.zip

Dataset:提供一种方式去获取数据及其lable

  • Q:如何获取每个数据及其lable

    重写构造方法和获取标签方法

  • Q:告诉我们总共有多少数据

    重写len方法

代码示例:

from torch.utils.data import Dataset
from PIL import Image
import os


class MyData(Dataset):

    def __init__(self, root_dir, label_dir):  # 获取所有图片的地址
        self.root_dir = root_dir  # self设置成全局变量
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)

    def __getitem__(self, idx):  # 获取图片及其标签
        img_name = self.img_path[idx]  # 获取文件名称
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)  # 文件绝对路径
        img = Image.open(img_item_path)  # 获取文件详细信息
        label = self.label_dir  # 获取标签
        return img, label	# 返回图片和标签

    def __len__(self):
        return len(self.img_path)


# 创建示例测试

root_dir = "hymenoptera_data/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)  # 获取蚂蚁的数据集
bees_dataset = MyData(root_dir, ants_label_dir)  # 获取蜜蜂的数据集

#获取整个train数据集
train_dataset = ants_dataset + bees_dataset

在控制台中进行测试

  • 对数据集中图片的相关操作

Dataset类实践_第1张图片

获取蚂蚁数据集

Dataset类实践_第2张图片

Dataset类实践_第3张图片

  • 查找当前数据集中第一个图片名称

在这里插入图片描述

  • 图片名称拼接(进行路径和标签的拼接)

在这里插入图片描述

Dataset类实践_第4张图片

Dataset类实践_第5张图片

  • 读取图片相应信息

Dataset类实践_第6张图片

Dataset类实践_第7张图片

实例化对象

Dataset类实践_第8张图片

Dataset类实践_第9张图片

  • 返回该对象 image和label

在这里插入图片描述
Dataset类实践_第10张图片

  • 结果

Dataset类实践_第11张图片

  • 改变ants_dataset[],展示第二张图
    在这里插入图片描述

Dataset类实践_第12张图片

获取蜜蜂的数据集

Dataset类实践_第13张图片

获取整个train数据集(蚂蚁+蜜蜂)

train_dataset = ants_dataset + bees_dataset

Dataset类实践_第14张图片

# 进行长度测试
len(train_dataset)
248
len(ants_dataset)
124
len(bees_dataset)
124

你可能感兴趣的:(深度学习,深度学习,人工智能)