PyTorch深度学习笔记(六)Dataset类代码实战

课程学习笔记,课程链接

# 常用的一些工具,torch 大工具箱中的常用工具区,然后是关于数据的 data 区
from torch.utils.data import Dataset
from PIL import Image  # 使用该方法来读取图片
import os  # python中关于系统的库

'''
  Dataset 是一个抽象类,所有的数据集都需要去继承这个类
  所有的子类都应该重写 __gititem__方法,该方法主要是获取每个数据及其 label
  同时还可以选择重新其中的 __len__,即数据有多长
  注:切换下一行可按 shift+回车
'''
class MyData(Dataset):
    # 初始化,根据这个类去创建实例时就需要运行的函数
    # 该函数会为整个 class 提供全局变量,为后面的函数提供量,可最后写
    def __init__(self, root_dir, label_dir):
        # self 能够把 self 指定的变量给后面的函数使用,相当于类中的全局变量
        self.root_dir = root_dir  
        self.label_dir = label_dir
        # 将路径通过斜杠加起来(拼接)
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)  # 将文件夹下的所有文件变为列表

    '''
    首先需要获得图片的地址,可以通过索引去获取
    通过idx就能获取所有图片地址的 list(列表)
    '''
    def __getitem__(self, idx):  # idx 作为编号
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)  # 读取
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)


# 创建实例
root_dir = r"D:\Code\Project\learn_pytorch\pytorch_p5-6\dataset\train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

# 小技巧
train_dataset = ants_dataset + bees_dataset  # 两个数据集的集合
img, label = train_dataset[124]
img.show()

你可能感兴趣的:(PyTorch,pytorch,深度学习,python)