Dataloader
与Dataset
前面学习到机器学习训练的五个步骤为:
Sampler
和DataSet
;Sample
的功能是生成索引,也就是样本的序号Dataset
是根据索引去读取数据以及对应的标签pytorch
中数据预处理是通过transforms
进行处理Dataloader
DataLoader(dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None)
Dataloader
的参数非常多,共有11个参数,但常用的就是下面五个:dataset
:Dataset
类,决定数据从哪里读取及如何读取batchsize
:批大小num_works
:是否多进程读取数据shuffle
:每个epoch是否乱序drop_last
:当样本数不能被batchsize整除时,是否舍弃最后一批数据epoch
:所有训练样本都已输入到模型中,称为一个epoch
,1个epoch
表示过了1遍训练集中的所有样本iteration
:一批样本输入到模型中,称之为一个iteration
(training step),每次迭代更新1次网络结构的参数batchsize
:批大小,表示一次迭代所使用的样本量,决定一个epoch
中有多少个iteration
drop_last
作用:样本总数 | Batchsize | drop_last | Epoch |
---|---|---|---|
87 | 8 | true | = 10 iteration |
87 | 8 | false | = 11 iteration |
Dataset
class Dataset(object):
def __getitem__(self, index):
raise NotImplementedError
def __add__(self, other)
return ConcatDataset([self,other])
Dataset
抽象类,所有自定义的Dataset
需要继承它,并且复写__getitem__()
getitem
:接收一个索引,返回一个样本对人民币二分类的数据进行读取,从以下三个方面了解Pytorch
的读取机制:
dataset_dir = os.path.join("/tmp/pytorch学习/WeekTwo/lesson-6/", "data", "RMB_data")
split_dir = os.path.join("/tmp/pytorch学习/WeekTwo/lesson-6/", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")
test_dir = os.path.join(split_dir, "test")
Resize
是对数据进行缩放RandomCrop
是对数据进行裁剪(起到数据增强的效果)ToTensor
是对数据进行转换,把图像转换成张量数据train_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
valid_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std),
])
注意:训练集中用到了RandomCrop
进行裁剪,但测试集中不需要要进行数据增强操作
Dataset
和DataLoader
Dataset
:必须是用户自己构建的,在Dataset
中会传入两个主要参数:
data_dir
:数据的路径(从哪里读取数据)transform
:数据预处理Dataloader
:构建数据迭代器,有两个主要参数:
Dataset
:前面构建好的RMBDataset
batch_size
:shuffle=True
表示每一个epoch中样本都是乱序的Dataset构建代码:
# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)
RMBDataset
的具体实现
class RMBDataset(Dataset):
def __init__(self, data_dir, transform=None):
"""
rmb面额分类任务的Dataset
:param data_dir: str, 数据集所在路径
:param transform: torch.transform,数据预处理
"""
self.label_name = {"1": 0, "100": 1} # 初始化部分
self.data_info = self.get_img_info(data_dir) # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
self.transform = transform
def __getitem__(self, index): # 函数功能是根据index索引去返回图片img以及标签label
path_img, label = self.data_info[index]
img = Image.open(path_img).convert('RGB') # 0~255
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
def __len__(self): # 函数功能是用来查看数据的长度,也就是样本的数量
return len(self.data_info)
@staticmethod
def get_img_info(data_dir): # 函数功能是用来获取数据的路径以及标签
data_info = list()
for root, dirs, _ in os.walk(data_dir):
# 遍历类别
for sub_dir in dirs:
img_names = os.listdir(os.path.join(root, sub_dir))
img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))
# 遍历图片
for i in range(len(img_names)):
img_name = img_names[i]
path_img = os.path.join(root, sub_dir, img_name)
label = rmb_label[sub_dir]
data_info.append((path_img, int(label)))
return data_info # 有了data_info,就可以返回上面的__getitem__()函数中的self.data_info[index],根据index索取图片和标签
注意:构建了两个Dataset
,一个用于训练,一个用于验证
Dataloader构建代码:
# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
步骤 | 源码实现 |
---|---|
读哪些数据 | sampler.py输出的Index |
从哪读数据 | Dataset中的参数data_dir |
怎么读数据 | Dataset的getitem()实现根据索引去读取数据 |
for
循环中去使用DataLoader
,进入DataLoader
之后是否采用多进程进入DataLoaderlter
,进入DataLoaderIter
之后会使用sampler
去获取Index
,拿到索引之后传输到DatasetFetcher
,在DatasetFetcher
中会调用Dataset
,Dataset
根据给定的Index
,在getitem
中从硬盘里面去读取实际的Img
和Label
,读取了一个batch_size
的数据之后,通过一个collate_fn
将数据进行整理,整理成batch_Data
的形式,接着就可以输入到模型中训练Sampler
决定的,从哪读是由Dataset
决定的,怎么读是由getitem
决定的