pytorch数据读取Dataloader与Dataset

数据
数据收集–>img,label
数据划分–>train,valid,test(详细见:https://blog.csdn.net/wyyyyyyfff/article/details/104381429)
数据读取–>dataloader–>sampler(index生成索引,样本序号),dataset(根据索引读取img,label)
数据预处理–>transforms

DataLoader
DataLoader是Pytorch中用来处理模型输入数据的一个工具类。通过使用DataLoader,我们可以方便地对数据进行相关操作,比如我们可以很方便地设置batch_size,对于每一个epoch是否随机打乱数据,是否使用多线程等等。

torch.utils.data.DataLoader(dataset, 
							batch_size=1, 
							shuffle=False, 
							num_works=0, 
							drop_last=False)

功能:构建可迭代的数据装载器
dataset:Dataset类,决定数据从哪读取以及如何读取
batch_size:批大小
shuffle:每个epoch是否乱序
num_works:是否多进程读取数据
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

关于DataLoader,需要先了解batchsize和epoch等输入数据的相关概念,以及python中类的基本知识比如继承和函数复写。

epoch:所有训练样本都已经输入到模型中,称为一个epoch
iteration:一批样本输入到模型中,称为一个iteration
batchsize:批大小,决定一个epoch有多少iteration
样本总数:87, batchsize:8
drop_last=True–>1 epoch=10 iteration
drop_last=False–>1 epoch=11 iteration

DataLoader的基本使用流程
1.首先会将原始数据加载到DataLoader中去,如果需要shuffle的话,会对数据进行随机打乱操作,这样能够输入顺序对于数据的影响。
2.再使用一个迭代器来按照设置好的batch大小来迭代输出shuffle之后的数据。
Tips: 通过使用迭代器能够有效地降低内存的损耗,会在需要使用的时候才将数据加载到内存中去。

Dataset 解决数据从哪里读取以及如何读取
功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写_getitem_()
getitem:接收一个索引,返回一个样本

使用Dataset来创建自己的数据类:

  1. 继承torch.utils.data.Dataset这个类
  2. 复写__getitem__ 和 __ len__ 这两个方法
  3. 如下图接收一个index,返回样本以及标签(如何读取样本,用户编写getitem)
    torch.utils.data.Dataset
class Dataset(object):
	def _getitem_(self, index):
		raise NotImplementedError
	def __len__(self):
        raise NotImplementedError
	def _add_(self, other):
		return ConcatDataset([self, other])

例子:

class MyDataset(Dataset): 
    """ my dataset."""
    
    # Initialize your data, download, etc.
    def __init__(self):
        # 读取csv文件中的数据
        xy = np.loadtxt('.csv', delimiter=',', dtype=np.float32) 
        self.len = xy.shape[0]
        # 除去最后一列为数据位,存在x_data中
        self.x_data = torch.from_numpy(xy[:, 0:-1])
        # 最后一列为标签为,存在y_data中
        self.y_data = torch.from_numpy(xy[:, [-1]])
        
    def __getitem__(self, index):
        # 根据索引返回数据和对应的标签
        return self.x_data[index], self.y_data[index]
        
    def __len__(self): 
        # 返回文件数据的数目
        return self.len

数据读取
1.读哪些数据 sampler输出的index
2.从哪读数据 Dataset中的data_dir
3.怎么读数据 Dataset中的getitem
pytorch数据读取Dataloader与Dataset_第1张图片
os.path.join(从哪里读数据,数据路径)

import os
split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
print(split_dir)
print(train_dir)
print(valid_dir)

…\data\rmb_split
…\data\rmb_split\train
…\data\rmb_split\valid

import shutil
shutil – Utility functions for copying and archiving files and directory trees.
(用于复制和存档文件和目录树的实用功能。)

详细见:https://blog.csdn.net/wyyyyyyfff/article/details/104381429

from PIL import Image
详细:https://www.cnblogs.com/lyrichu/p/9124504.html
https://blog.csdn.net/Li_qf/article/details/84925027
https://blog.csdn.net/zhangziju/article/details/79123275?utm_source=distribute.pc_relevant.none-task

你可能感兴趣的:(pytorch数据读取Dataloader与Dataset)