数据模块可分为以下几部分:
● 数据的收集:Image、label
● 数据的划分:train、test、valid
● 数据的读取:DataLoader,有两个子模块,Sampler和Dataset,Sampler是对数据集生成索引index,DataSet是根据索引读取数据
● 数据预处理:torchvision.transforms模块
torch.utils.data.DataLoader():构建可迭代的数据装载器,在训练数据时,每一个for循环,就是一次iteration,就是从DataLoader中获取一个batchsize大小的数据。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
dataset:Dataset类,决定从哪读取以及如何读取数据;
batch_size:int型,批量的大小
shuffle:每个epoch的数据是否打乱
num_workers:是否进行多进程读取数据,若采取多进程,减少读取数据的时间,可以加速模型训练
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
torch.utils.data.Dataset():Dataset类,所有自定义的数据集都要继承这个类,并且复写__getitem__()这个类方法,定义数据从哪里读取以及如何读取
class Dataset(object):
def __init__(self):
pass
def __len__(self):
raise NotImplementedError
def __getitem__(self,index):
#接受一个索引,返回一个样本
raise NotImplementedError
首先,dataset的初始化,即定义数据的来源,一般有文件夹或者是用H5文件封装好的。
直接从文件夹读取图片 ,一般可以使用glob.glob()函数或者是os.listdir()来获取图片名称列表。
下面的代码是我用来超分辨率重建过程中用到的可以用来获取成对的HR,LR
import glob
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
class SRdataset(Dataset):
def __init__(self, path,scale):
"""Initialization"""
self.list_ids = glob.glob('{}/*.png'.format(path))
self.patch_size = 128
self.scale=scale
def __len__(self):
"""Denotes the total number of samples"""
return len(self.list_ids)
def __getitem__(self, index):
# 读取数据
img = Image.open(self.list_ids[index])
##
self.width,self.heigt=img.size[0],img.size[1]
self.width=self.width//self.scale*self.scale
self.height=self.heigt//self.scale*self.scale
hr=img.resize((self.width,self.heigt),Image.BICUBIC)
##将图片随机裁剪到固定大小
hr=transforms.RandomCrop((self.patch_size, self.patch_size))(img)
##将图片转换为ycbcr格式
hr = hr.convert('YCbCr')
hr = hr.getchannel(0)
###将图片下采样得到低分辨率图片
lr=hr.resize((int(hr.size[0] // self.scale), int(hr.size[1] //self.scale)), Image.BICUBIC)
return hr,lr
# sr=SRdataset('./Set14',3)
# print(len(sr))
# hr,lr=sr[0]
# print('hr size',hr.size)
# print('lr size',lr.size)
# hr.show()
# lr.show()
上面的操作比较麻烦,也可以用torchvision中的transform进行变换
import glob
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
class SRdataset(Dataset):
def __init__(self, path,scale):
"""Initialization"""
self.list_ids = glob.glob('{}/*.png'.format(path))
self.patch_size = 100
self.scale=scale
def __len__(self):
"""Denotes the total number of samples"""
return len(self.list_ids)
def __getitem__(self, index):
# 读取数据
img = Image.open(self.list_ids[index])
hr_transform = transforms.Compose([
transforms.Resize((img.size[1]//self.scale*self.scale, img.size[0]//self.scale*self.scale),Image.BICUBIC),
transforms.RandomCrop((self.patch_size//self.scale*self.scale, self.patch_size//self.scale*self.scale), padding=4)
])
hr=hr_transform(img)
lr_transform=transforms.Compose([
transforms.Resize((hr.size[1] // self.scale , hr.size[0] // self.scale),
Image.BICUBIC),
])
lr=lr_transform(hr)
# ##将图片转换为ycbcr格式
# # hr = hr.convert('YCbCr')
# # hr = hr.getchannel(0)
return hr,lr
sr=SRdataset('./Set14',3)
print(len(sr))
hr,hr1=sr[0]
print('hr size',hr.size)
print('hr1 size',hr1.size)
● DataLoader的作用就是提供一个数据装载器,根据batch size的大小,将数据分成一个个batch去训练模型,而分数据的这个过程需要把数据读取到,这个借助Dataset中的__getitem__方法来获取样本数据。
● 在构建自定义数据集时,需要继承Dataset,并且复现__getitem__方法,实现数据怎么读,另外要重写__len__方法,返回多少个数据样本
pytroch学习笔记三:数据的读取机制_Dear_林的博客-CSDN博客