Pytorch Dataset、Dataloader的简单理解与使用

本文以torch.utils.data中的Dataset类为例进行说明

Dataset的作用是构建自定义的数据集,以方便使用Dataloader进行加载

语法

我们自定义的数据集需要继承自torch.util.data.Dataset抽象类,并重写相应的两个方法:

  • len:返回数据集的大小。一般情况而言直接用 len(xxx) 进行实现即可
  • getitem:使得 dataset[i] 能够返回数据集中的第i个样本,相应的需要传入一个索引i

原抽象类中相应的定义如下:

def __getitem__(self, index):
    raise NotImplementedError

def __len__(self):
    raise NotImplementedError

数据

假设我们在解决一个分类问题。那么,在训练集文件夹train中,我们可以这么给图片加上标签:
Pytorch Dataset、Dataloader的简单理解与使用_第1张图片
到时候就可以通过文件名的方式来判断某张图片对应的分类。

例子

我们构造一个FruitDataset来处理这些数据。首先实现init方法:

def __init__(self, root_dir, transform=None): 
    self.root_dir = root_dir   
    self.transform = transform 
    self.images = os.listdir(self.root_dir)

init方法一般会有两个基础的参数,一个是dir,用来表示数据集所在的目录;另一个则是transform,可以传入一些方法,以对图片进行处理(一般是进行数据增强)。
此外,在init方法中还会进行基础的数据读取,例如这里使用listdir来列出目录下的所有文件;而如果是表格形式的数据(如kaggle),那么则可以使用切片方法将标签与数据分离,方便后续的处理。

接下来是len方法,即返回数据集的长度。既然我们刚才已经读出了数据集目录下的所有文件,那么只要返回这个文件夹列表的长度即可:

def __len__(self):
    return len(self.images)

最后则是getitem方法。getitem方法返回的是一个字典,表示相应数据所蕴含的其他信息,有了其他信息一个数据才能变成一个样本。在这里,“其他信息”就是图片所对应的标签,即要返回一个{‘image’:img, ‘label’:label}。习惯上,我们会把这个字典记做sample。
img可以通过imread方法读取图片实际内容得到,而label可以通过处理文件名获得:

def __getitem__(self,index):
	# 通过路径与索引读图片
    image_index = self.images[index]
    img_path = os.path.join(self.root_dir, image_index)
    img = io.imread(img_path)
    # 通过文件名读标签
    label = img_path.split('\\')[-1].split('.')[0]
    # 组装成字典
    sample = {'image':img,'label':label}
    if self.transform:
        sample = self.transform(sample)
    return sample 

注意这里的if self.transform也算是一种习惯上的用法,即如果传入了变换方法则进行变换后再返回。

Dataloader

我们通过dataloader来分析刚才构建的数据集。一般来说,训练集与测试集各会对应一个dataloader,这里为了演示方便起见就只拿我们刚才的训练集进行说明。
首先,实例化一个Dataset对象。在这里我们没有变换方法,则只需要传入数据所在的目录即可:

data = FruitDataset(r"data\train", transform=None)

dataset对象可以通过下标来访问其中的各个样本,比如:

print(data[0])

然后利用dataloader进行加载:

dataloader = DataLoader(data, batch_size=2, shuffle=True)

一般而言Dataloader需要传入三个参数:

  • dataset:传入Dataset对象,表示需要加载的数据集
  • batch_size:“批大小”,表示一次选取的一批中有几个样本。在这里bs为2,即每轮选取2个样本
  • shuffle:是否需要将数据打乱。一般来说只需要打乱训练集即可,测试集并不需要打乱

查看dataloader的长度。总共有10张图,一批有2张,因此有5批,长度为5:

# 5
print(len(dataloader)) 

最后迭代整个数据集:

for i_batch, batch_data in enumerate(dataloader):
    print(i_batch)
    print(batch_data)

i_batch就是batch的编号,0、1、2、3、4;batch_data就是我们在数据集中定义的sample,在这里两个两个一组出现。

完整代码

# -*- coding: utf-8 -*-
from torch.utils.data import Dataset, DataLoader
from skimage import io
import os

class FruitDataset(Dataset): 
    def __init__(self, root_dir, transform=None): 
        self.root_dir = root_dir  
        self.transform = transform 
        self.images = os.listdir(self.root_dir)
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self,index):
        image_index = self.images[index]
        img_path = os.path.join(self.root_dir, image_index)
        img = io.imread(img_path)
        label = img_path.split('\\')[-1].split('.')[0]
        sample = {'image':img,'label':label}
        if self.transform:
            sample = self.transform(sample)
        return sample 

data = FruitDataset(r"data\train", transform=None)
print(data[0])
dataloader = DataLoader(dataset=data, batch_size=2, shuffle=True)
print(len(dataloader))
for i_batch, batch_data in enumerate(dataloader):
    print(i_batch)
    print(batch_data)

参考

https://blog.csdn.net/xuan_liu123/article/details/101145366

你可能感兴趣的:(Pytorch)