【学习笔记】【Pytorch】一、Dataset类代码实战

【学习笔记】【Pytorch】一、Dataset类代码实战

  • 学习地址
  • 主要内容
  • 代码实现

学习地址

PyTorch深度学习快速入门教程.

主要内容

【学习笔记】【Pytorch】一、Dataset类代码实战_第1张图片
补充知识点
1.dir()函数,能让我们知道工具箱以及工具箱中的分隔区有什么东西。即查看包内的模块;
2.help()函数,能让我们知道每个工具是如何使用的,工具的使用方法。即模块内函数的使用方法。
pytorch 加载数据的规则:定义一个class类,继承 Dataset (from torch.utils.data import Dataset),并且,在类中,必须重写__getitem__(按索引返回img、label)、选重写__len__(返回数据集大小)

代码实现

from torch.utils.data import Dataset
import os
from PIL import Image


class MyData(Dataset):

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)  # 合并
        self.img_path = os.listdir(self.path)  # 当前路径下所有文件组成的列表

    def __getitem__(self, item):
        img_name = self.img_path[item]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)  # 使用PIL的Image打开图片
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)  # 对列表取长,即返回当前类别路径下的图片个数


ants_dataset = MyData("dataset/train", "ants")
bees_dataset = MyData("dataset/train", "bees")

train_dataset = ants_dataset + bees_dataset

Debug模式下
【学习笔记】【Pytorch】一、Dataset类代码实战_第2张图片
【学习笔记】【Pytorch】一、Dataset类代码实战_第3张图片

你可能感兴趣的:(Pytorch,pytorch,python,深度学习,学习,人工智能)