Pytorch深度学习-----数据模块Dataset类

文章目录

      • 前言
      • 一、概念
      • 二、使用
      • 三、代码演示

前言

基于B站PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】课程进行学习总结

一、概念

dataset类是一个抽象类,即不能进行实例化,要想实例化可以使用创建一个子类来进行
作用就是将实现通过索引访问对应的数据以及标签

二、使用

step1:需要导入torch.utils.data模块
step2:创建一个子类继承Dataset
step3:实现三个函数

init :初始化函数,为自定义类设置成员变量
len :返回样本个数
getitem :核心函数,实现按索引返回数据及标签

三、代码演示

from torch.utils.data import Dataset
from PIL import Image
import os

class MyData(Dataset):  # 继承Dataset
    def __init__(self,root_dir,label_dir): # 这个方法主要存放下面要用到属性和方法
        self.root_dir = root_dir  # 相当于这里的train文件夹所在的路径
        self.label_dir = label_dir  # 标签:相当于ants文件夹
        self.path = os.path.join(self.root_dir,self.label_dir)  # 将上述相对路径以及标签进行拼接,形成一个完整的路径,用于下面将这些图片名称转化为列表
        self.img_path = os.listdir(self.path)  # 将所有数据集中的图片名称转化为一个列表

    def __getitem__(self, idx):  # 根据索引获取数据集里面的单个数据
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)  # 形成一张完整的图片路径
        img = Image.open(img_item_path) # 将某个路径的图片转化成一个图像对象
        label = self.label_dir
        return img,label  # 将图片对象和标签返回

    def __len__(self):
        return len(self.img_path) # 根据上面形成的数据集列表返回数据的长度

# 蚂蚁数据集
root_dir = "D:\\Python\\pytorch\\hymenoptera_data\\train"
ants_label_dir = "ants_image"
ants_dataset = MyData(root_dir,ants_label_dir)  # 以所在图片的上一级上一级路径和图片所在的文件夹为实参进行传入
# 蜜蜂数据集
bees_label_dir = "bees"
bees_dataset = MyData(root_dir,bees_label_dir)
img,label = ants_dataset.__getitem__(1)
print(ants_dataset.img_path)  # 将存储蚂蚁的数据以列表形式打印
img.show()  # 展示蚂蚁数据集列表中索引值为1的图片
print(f"蚂蚁数据集列表长度:{ants_dataset.__len__()}")
# 把两个数据集进行相加
toller = ants_dataset+bees_dataset


你可能感兴趣的:(pytorch深度学习,深度学习,pytorch,人工智能)