PyTorch学习笔记(三)总结篇 --------自建数据集的载入

前言

经过这几天学习,我算是把数据集这一块给摸清楚了,前面分布分支的学习总是有点模棱两可,不清楚这步到底要干啥,在网上找资料学习时,总是拿的pytorch官网给的数据集,没有针对性和专一性。这里教大家如何使用咱们自己的数据集,当然,在做实验时数据集是通过爬虫来获取的,关于爬虫的相关知识可以留言私信,或者看我第一篇博客哦

一、MyData类的定义

在自建数据集时需要自己去定义一个dataset类来继承torch.utils.data.Dataset

来看代码

class MyData(Dataset):
    def __init__(self, root_dir, label_dir, transform=None):  # 初始化类,为class提供全局变量
        self.transform = transform
        self.root_dir = root_dir  # 根文件位置
        self.label_dir = label_dir  # 子文件名
        self.path = os.path.join(self.root_dir, self.label_dir)  # 合并,即具体位置
        self.img_path = os.listdir(self.path)  # 转换成列表的形式

    def __getitem__(self, idx):  # 获取列表中每一个图片
        img_name = self.img_path[idx]  # idx表示下标,即对应位置
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)  # 每一个图片的位置
        img = Image.open(img_item_path)  # 调用方法,拿到该图像
        img = img.convert("RGB")
        img = self.transform(img)
        label = self.label_dir  # 标签
        return img, label  # 返回img 图片 label 标签

    def __len__(self):  # 返回长度
        return len(self.img_path)

这里要注意一下,跟第一篇学习笔记的不同在于第一篇没有定义transform导致返回的就是PIL类型,这边加入了几行代码,目的就是为了返回tensor类型,这里返回的是img,label两个对象

二、数据集的实际运用

上面的类是通法,在任何研究中都可以套用,下面来看在本次实验中的实际运用

tensor_trans = transforms.Compose([transforms.ToTensor(), transforms.Resize([512, 512])])
root_dir = 'D://情绪图片'  # 根目录
happy_label_dir = '开心'  # 子目录
happy_dataset = MyData(root_dir, happy_label_dir, transform=tensor_trans)  # 开心数据集创建完成

这边是通过totensor转换为tensor类型,同时将图片尺寸变为512*512

也算so easy吧

三、DataLoader类

dataloader是用来load数据集,其中batch_size=4是为了每次抓取4张;shuffle是按需求来是否需要打乱,即在等于True的时候是打乱的,False的时候是不打乱的;drop_last是表示在最后一次抓取时不满4个是否需要保留,比如一共10张图片,每次抓取4个,最后一次不满4个可以选择保留或者舍弃

test_loader = DataLoader(dataset=happy_dataset, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
# img, label = happy_dataset[0]
step = 0
writer = SummaryWriter("dataloader")
for epoch in range(2):
    for data in test_loader:
        imgs, label = data
        writer.add_images('{}'.format(epoch), imgs, step)
        step = step + 1
writer.close()

四、源码

这边是最终的源码,大家可以按自己的要求选取哦

# -*- coding = utf-8 -*-
import torchvision
from torch.utils.data import Dataset
import cv2
from PIL import Image  # 图像处理的库
import os

from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torch.utils.data import DataLoader

class MyData(Dataset):
    def __init__(self, root_dir, label_dir, transform=None):  # 初始化类,为class提供全局变量
        self.transform = transform
        self.root_dir = root_dir  # 根文件位置
        self.label_dir = label_dir  # 子文件名
        self.path = os.path.join(self.root_dir, self.label_dir)  # 合并,即具体位置
        self.img_path = os.listdir(self.path)  # 转换成列表的形式

    def __getitem__(self, idx):  # 获取列表中每一个图片
        img_name = self.img_path[idx]  # idx表示下标,即对应位置
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)  # 每一个图片的位置
        img = Image.open(img_item_path)  # 调用方法,拿到该图像
        img = img.convert("RGB")
        img = self.transform(img)
        label = self.label_dir  # 标签
        return img, label  # 返回img 图片 label 标签

    def __len__(self):  # 返回长度
        return len(self.img_path)
tensor_trans = transforms.Compose([transforms.ToTensor(), transforms.Resize([512, 512])])
root_dir = 'D://情绪图片'  # 根目录
happy_label_dir = '开心'  # 子目录
happy_dataset = MyData(root_dir, happy_label_dir, transform=tensor_trans)  # 开心数据集创建完成

# img, label = happy_dataset[2]  # 由上面可知返回的是两个值
# print(label)  # 分别调用
# img.show()
# batch_size每次取dataset的四个数据集并打包, shuffle是是否打乱,drop_last为False即最后一步不满4个时不舍,反之舍
# print(happy_dataset[0])
# tensor_trans = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])
# test_data = torchvision.datasets(datasets=happy_dataset, transforms=tensor_trans)

test_loader = DataLoader(dataset=happy_dataset, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
# img, label = happy_dataset[0]
step = 0
writer = SummaryWriter("dataloader")
for epoch in range(2):
    for data in test_loader:
        imgs, label = data
        writer.add_images('{}'.format(epoch), imgs, step)
        step = step + 1
writer.close()

下面来看结果

PyTorch学习笔记(三)总结篇 --------自建数据集的载入_第1张图片

 

五、总结

这几天对数据集的学习对我这个初次接触pytorch的人来说也是挺头疼的,各种报错,包括loader的基本使用当时也是不太熟悉。现在大部分教学用的数据集都是通过torchvision.dataset.来获取pytorch自带的数据集,这篇也算是给大家提供另一种方法吧

最后有什么不明白或者报错不知道怎么解决的可以留言私信哦~说不定我曾经也经历过

你可能感兴趣的:(PyTorch深度学习,pytorch,python,机器学习,深度学习)