简而言之,这俩就是自动帮我们取数据,避免了接触底层代码
机器学习模型训练五大步骤;第一是数据,第二是模型,第三是损失函数,第四是优化器,第五个是迭代训练过程。
这里主要学习数据模块当中的数据读取,数据模块通常还会分为四个子模块:数据收集、数据划分、数据读取、数据预处理。
在进行实验之前,需要收集数据,数据包括原始样本和标签;
有了原始数据之后,需要对数据集进行划分,把数据集划分为训练集、验证集和测试集;训练集用于训练模型,验证集用于验证模型是否过拟合,也可以理解为用验证集挑选模型的超参数,测试集用于测试模型的性能,测试模型的泛化能力;
第三个子模块是数据读取,也就是即将要学习的DataLoader,pytorch中数据读取的核心就是DataLoader;
第四个子模块是数据预处理,把数据读取进来往往还需要对数据进行一系列的图像预处理,比如数据的中心化,标准化,旋转或者翻转等等。pytorch中数据预处理是通过transforms进行处理的;
详情请见原文链接。
经过debug实践和总结后,如下。
(1)torch.utils.data.DataLoader
功能:构建可迭代的数据装载器;
dataset: Dataset类,决定数据从哪里读取及如何读取; batchsize:批大小;
num_works:
是否多进程读取数据;
shuffle: 每个epoch是否乱序;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
如
DataLoader( dataset = dataset , #dataset 是继承了dataset类之后加载数据集提供路径
batch_size = 32, #选择batch_size的大小
shuffle = true, #增强数据集随机性
num_workers = 2 ) #多进程读数据
再次强调
(2)torch.utils.data.Dataset
Dataset是用来定义数据从哪里读取,以及如何读取的问题;
功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__();
getitem:接收一个索引,返回一个样本
下面对人民币二分类的数据进行读取,从三个方面了解pytorch的读取机制,分别为读哪些数据、从哪读数据、怎么读数据;
具体来说,在每一个Iteration的时候应该读取哪些数据,每一个Iteration读取一个Batch大小的数据,假如有80个样本,那么从80个样本中读取8个样本,那么应该读取哪八个样本,这就是我们的第一个问题,读哪些数据;
意思是在硬盘当中,我们应该怎么找到对应的数据,在哪里设置参数;
从代码中学习可以发现,数据的获取是从DataLoader迭代器中不停地去获取一个Batchsize大小的数据,通过for循环获取的;
下面开始debug调试看读取数据的过程。
首先在pycharm中对
for i, data in enumerate(train_loader):
这一行代码设置断点,然后执行Debug,然后点击步进功能键,就可以跳转到对应的函数中,可以发现是跳到了dataloader.py中的__iter__()函数;具体如下所示:
def __iter__(self):
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
return _MultiProcessingDataLoaderIter(self) #进程问题
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self,loader):
super(_SingleProcessDataLoaderIter,self).__init__(loader)
assert self.timeout == 0
assert self.num_workers == 0
self.dataset_fetcher = _DatasetKind.create_fetcher(self.dataset_kind, self.dataset,self.auto_collation, self.collate_fn, self.drop_last)
def __next__(self):
index = self._next_index() # may raise StopIteration
data = self.dataset_fetcher.fetch(index) # may raise StopIteration
if self.pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
next = __next__ # Python 2 compatibility
def _next_index(self):
return next(self.sampler_iter) # may raise StopIteration
Index={list}
: [4, 135, 113, 34, 47, 140, 87, 0, 59, 33, 144, 43, 83, 133, 1, 78] #That’s explained everything!
self={_SingleProcessDataLoaderlter}
class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index] #调用了dataset,通过一系列的data拼接成一个list;
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
def __getitem__(self, index):
path_img, label = self.data_info[index]
img = Image.open(path_img).convert('RGB') # 0~255
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
这里已经实现了data_info()函数,即对数据进行了初步的读取,已经得到了图片的路径和标签的列表了(回答了问题2!),再把index相应的值读出来即可(
关于如何加载自定义数据集见这个博客,这里就解释了博客里讲的为什么要分两步,第一步制作标签索引的txt文件,第二步写Dataset类的getitem函数);然后通过Image.open实现了一个数据的读取(回答了问题3!);
之后点击step_out跳出该函数,会返回fetch()函数中;
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
通过以上的分析,可以回答一开始提出的数据读取的三个问题:1、读哪些数据;2、从哪读数据;3、怎么读数据;
(1)从代码中可以发现,index是从sampler.py中输出的一个列表,所以读哪些数据是由sampler得到的;
(2)从代码中看,是从Dataset中的参数data_dir,告诉我们pytorch是从硬盘中的哪一个文件夹获取数据;
如
train_data = MyDataset(txt='../gender/train1.txt',type = "train", transform=transform_train) #Dataset类是自己写的,传进去的data_dir即"txt='../gender/train1.txt' "参数
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)
(3)从代码中可以发现,pytorch是从Dataset的getitem()中具体实现的,根据索引去读取数据;
Dataloader读取数据很复杂,需要经过四五个函数的跳转才能最终读取数据
为了简单,通过流程图来对数据读取机制有一个简单的认识;
学习完这里真的太不容易了,但搞清楚一个重要的事情的来龙去脉还是很有成就感的!继续加油!
by 小李
如果你坚持到这里了,请一定不要停,山顶的景色更迷人!好戏还在后面呢。加油!
欢迎交流学习和批评指正,你的点赞和留言将会是我更新的动力!谢谢