pytorch加载数据

参考:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
本文是上面视频的笔记,up主讲的特别详细,推荐观看。
在pytorch中加载数据主要涉及到两个类:Dataset 和 Dataloader
Dataset :提供一种方式去提取数据并得到label
Dataset:对数据进行打包送到网络中去,为后面的网络提供不同的数据形式。
下面是代码及说明:

from torch.utils. data import Dataset

pytorch加载数据_第1张图片
可看到说明,Dataset是一个抽象类,我们重写Dataset时要继承这个类,所有的子类都应该重写__getitem__()方法,这个方法作用是获取数据及对应的labe。同时我们可以选择性地去重写__len__方法,其作用是获取数据集长度。

例子:

这里我使用的是猫狗二分类的数据集,如图:
pytorch加载数据_第2张图片

from torch.utils. data import Dataset
from PIL import Image
import os

class Mydataset(Dataset):
    def __init__(self,root_dir, label_dir):
        self.root_dir = root_dir    ##根目录
        self.label_dir = label_dir  ##标签,也就是文件名
        self.path = os.path.join(self.root_dir,self.label_dir)  ##拼成一个完整的目录
        self.img_path = os.listdir(self.path) ##获得图片的一个list

    def __getitem__(self, idx):
        img_name = self.img_path[idx]  ##得到单个图片的名字
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)  ##得到单个图片的路径
        img = Image.open(img_item_path) ##图片数据
        label = self.label_dir ##标签
        return img, label
    def __len__(self):
        return len(self.img_path)

root_dir="D:/猫狗大战/data/train"
cat_label_dir = "cat"
dog_label_dir = "dog"
cat_dataset = Mydataset(root_dir,cat_label_dir)
dog_dataset = Mydataset(root_dir,dog_label_dir)
img, label = cat_dataset[1]
img.show()
print(label)

img, label = dog_dataset[1]
img.show()
print(label)

输出结果:
cat
dog
pytorch加载数据_第3张图片
写给自己,另外,可以参考这篇博客:
https://ptorch.com/news/215.html
fastai也可以关注以下

你可能感兴趣的:(pytorch,基础)