Pytorch学习笔记:自定义数据集

文章目录

  • 前言
  • 1.torch.utils.data.Dataset介绍
  • 2.实例
    • a.准备数据集
    • b.复写 Dataset
    • c.DataLoader加载
  • 总结


前言

初学Pytorch时,数据集直接使用torchvision.datasets调用,然后直接使用torch.untils.data.DataLoader加载。
在实际项目中,我们怎么自定义数据集呢?


1.torch.utils.data.Dataset介绍

`torch.utils.data.Dataset` 是一个抽象类,用户想要加载自定义的数据只需要继承这个类,并且覆写其中的三个方法即可:
  1. __init__:构建函数,self属性
  2. __len__:实现len(dataset)返回整个数据集的大小。
  3. __getitem__用来获取一些索引的数据,使dataset[i]返回数据集中第i个样本。
    不覆后面两个方法会直接返回错误。

2.实例

a.准备数据集

以分类网络为例,准备原始数据:
常见的两种形式的导入:

一是整个数据集都在一个文件下,内部再另附一个label文件,说明每个文件的状态。这种存放数据的方式可能更时候在非分类问题上得到应用

二是更适合在分类问题上,即把不同种类的数据分为不同的文件夹存放起来。这样,我们可以从文件夹或文件名得到label。或者一个统一的Label文件

b.复写 Dataset

#导入相关模块
from torch.utils.data import DataLoader,Dataset
from skimage import io,transform
import matplotlib.pyplot as plt
import os
import torch
from torchvision import transforms
import numpy as np

class AnimalData(Dataset): #继承Dataset
    def __init__(self, root_dir, transform=None): #__init__是初始化该类的一些基础参数
        self.root_dir = root_dir   #文件目录
        self.transform = transform #变换
        self.images = os.listdir(self.root_dir)#目录里的所有文件
    
    def __len__(self):#返回整个数据集的大小
        return len(self.images)
    
    def __getitem__(self,index):#根据索引index返回dataset[index]
        image_index = self.images[index]#根据索引index获取该图片
        img_path = os.path.join(self.root_dir, image_index)#获取索引为index的图片的路径名
        img = io.imread(img_path)# 读取该图片
        label = img_path.split('\\')[-1].split('.')[0]# 根据该图片的路径名获取该图片的label,具体根据路径名进行分割。我这里是"E:\\Python Project\\Pytorch\\dogs-vs-cats\\train\\cat.0.jpg",所以先用"\\"分割,选取最后一个为['cat.0.jpg'],然后使用"."分割,选取[cat]作为该图片的标签
        sample = {'image':img,'label':label}#根据图片和标签创建字典
        
        if self.transform:
            sample = self.transform(sample)#对样本进行变换
        return sample #返回该样本

c.DataLoader加载

if __name__=='__main__':
    data = AnimalData('E:/Python Project/PyTorch/dogs-vs-cats/train',transform=None)#初始化类,设置数据集所在路径以及变换
    dataloader = DataLoader(data,batch_size=128,shuffle=True)#使用DataLoader加载数据
    for i_batch,batch_data in enumerate(dataloader):
        print(i_batch)#打印batch编号
        print(batch_data['image'].size())#打印该batch里面图片的大小
        print(batch_data['label'])#打印该batch里面图片的标签

输出如下:

0
torch.Size([128, 3, 224, 224])
['dog', 'cat', 'dog', 'dog', 'cat', 'dog', 'cat', 'cat', 'dog', 'dog', 'dog', 'dog', 'cat', 'cat', 'dog', 'dog', 'cat', 'dog', 'cat', 'cat', 'cat', 'cat', 'dog', 'cat', 'cat', 'cat', 'cat', 'cat', 'dog', 'dog', 'cat', 'cat', 'dog', 'cat', 'cat', 'cat', 'cat', 'dog', 'dog', 'dog', 'cat', 'dog', 'cat', 'cat', 'cat', 'dog', 'cat', 'dog', 'dog', 'dog', 'dog', 'dog', 'cat', 'cat', 'dog', 'cat', 'cat', 'cat', 'cat', 'dog', 'dog', 'cat', 'dog', 'cat', 'cat', 'cat', 'cat', 'dog', 'cat', 'dog', 'dog', 'cat', 'dog', 'dog', 'cat', 'cat', 'dog', 'dog', 'dog', 'dog', 'dog', 'cat', 'cat', 'dog', 'dog', 'cat', 'dog', 'cat', 'dog', 'cat', 'cat', 'dog', 'cat', 'dog', 'cat', 'dog', 'cat', 'dog', 'cat', 'cat', 'cat', 'dog', 'cat', 'cat', 'dog', 'dog', 'dog', 'cat', 'cat', 'cat', 'cat', 'dog', 'cat', 'dog', 'dog', 'dog', 'dog', 'cat', 'dog', 'dog', 'dog', 'dog', 'cat', 'cat', 'dog', 'cat', 'dog', 'cat']

总结

  1. torch.utils.data.Dataset 的结构,如何自定义数据集

参考:文章

你可能感兴趣的:(Pytorch学习笔记,pytorch,深度学习,python)