PyTorch学习—6.PyTorch数据读取机制Dataloader与Dataset

文章目录

      • 一、PyTorch数据读取机制Dataloader

一、PyTorch数据读取机制Dataloader

  PyTorch数据读取在Dataloader模块下,Dataloader又可以分为DataSet与Sampler。Sampler模块的功能是生成索引(样本序号);DataSet是依据索引读取Img、Lable。我们主要学习Dataloader与Dataset。
  torch.utils.data.DataLoader()

DataLoader(dataset,
			batch_size=1,shuffle=False,sampler=None,
			batch_sampler=None,num_workers=0,
			collate_fn=None,pin_memory=False,drop_last=False,timeout=0,
			worker_init_fn=None,
			multiprocessing_context=None)

功能:构建可迭代的数据装载器

  • dataset: Dataset类,决定数据从哪读取及如何读取
  • batch_size :批大小
  • num_works:是否多进程读取数据
  • shuffle:每个epoch是否乱序
  • drop_last :当样本数不能被batchsize整除时,是否舍弃最后一批数据

Epoch:所有训练样本都已输入到模型中,称为一个Epoch
Iteration:一批样本输入到模型中,称之为一个lteration
Batchsize:批大小,决定一个Epoch有多少个lteration
样本总数:80,Batchsize : 8
1 Epoch = 10 lteration
样本总数:87, Batchsize: 8
1 Epoch = 10 lteration ? drop_last = True
1 Epoch = 11 lteration drop_last = False

  torch.utils.data.Dataset()

class Dataset(object):
	def __getitem__(self,index):
		raise NotImplementedError
	def __add__(self, other) :
		return ConcatDataset([self, other])

功能: Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()
getitem:接收一个索引,返回一个样本
数据读取流程如下:
PyTorch学习—6.PyTorch数据读取机制Dataloader与Dataset_第1张图片

for i, data in enumerate(train_loader):

==>

# 判断是单进程还是多进程
    def __iter__(self):
    	# 单进程
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        # 多进程
        else:
            return _MultiProcessingDataLoaderIter(self)

==>
# 以单进程为例
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self.timeout == 0
        assert self.num_workers == 0

        self.dataset_fetcher = _DatasetKind.create_fetcher(
            self.dataset_kind, self.dataset, self.auto_collation, self.collate_fn, self.drop_last)
	
	# 这个函数告诉我们每个iteration中读哪些数据
    def __next__(self):
    	# 
        index = self._next_index()  # may raise StopIteration
        data = self.dataset_fetcher.fetch(index)  # may raise StopIteration
        if self.pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

    next = __next__  # Python 2 compatibility


==>

    def _next_index(self):
        return next(self.sampler_iter)  # may raise StopIteration


==>

	# 利用sampler输出的index来进行采样
    def __iter__(self):
        batch = []
        # 
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

==>

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
        	# 这一步实现了正式的数据读取
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)


==>


class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.transform = transform

    def __getitem__(self, index):
    	# 根据索引index获得数据与标签
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        # 遍历一个目录内,各个子目录与子文件
        for root, dirs, _ in os.walk(data_dir):
            # 遍历类别
            for sub_dir in dirs:
                img_names = os.listdir(os.path.join(root, sub_dir))
                img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

                # 遍历图片
                for i in range(len(img_names)):
                    img_name = img_names[i]
                    path_img = os.path.join(root, sub_dir, img_name)
                    label = rmb_label[sub_dir]
                    data_info.append((path_img, int(label)))

        return data_info


==>


class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        # 数据的整理器,将读取到的数据整理成batch的形式
        return self.collate_fn(data)

==>


    for i, data in enumerate(train_loader):

        # forward
        # data由两个Tensor组成
        inputs, labels = data

数据整理器将数据由下面的形式:PyTorch学习—6.PyTorch数据读取机制Dataloader与Dataset_第2张图片
转化为batch形式:
在这里插入图片描述


如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!
在这里插入图片描述


你可能感兴趣的:(PyTorch框架学习,pytorch,数据读取机制,Dataloader)