【Pytorch学习笔记】数据导入

1. 前言

Pytorch的数据导入依靠 torch.utils.data.DataLoader 和torch.utils.data.Dataset(或torch.utils.data.IterableDataset)两个类来实现。

2. torch.utils.data.DataLoader学习

在 torch.utils.data 官方文档中提到,torch.utils.data.DataLoader 是pytorch 数据导入的核心工具,返回一个可迭代对象用于提取数据集中的数据。

DataLoader的参数如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

2.1 Dataset

dataset 参数表示要加载数据的数据集对象。PyTorch支持两种不同类型的数据集:map-style datasets 和 iterable-style datasets。

2.1.1 Map-style datasets

Map-style 数据集应用__getitem()__ 和__len__()两种协议,是一种索引/键值与数据集样本的映射关系。也就是说 dataset 并不是读取了数据,而是读取了数据的索引/键值,后续通过这个索引/键值来访问数据。当我们访问 dataset[idx] 时,可以从磁盘读取第 idx 张图片和与之对应的标签,这一过程是对磁盘上的数据样本随机访问的情况。
其中__getitem()__是通过给定的索引/键值返回相应的数据集样本,len()是返回数据集的大小。

2.1.2 Iterable-style datasets

Iterable-style 数据集应用__iter()__协议,是一个可迭代的数据集。Iterable_style 数据集是读取了数据。这种类型的数据集特别适合于随机读取昂贵甚至不可能的情况,以及批量大小取决于所取数据的情况。
Note:对于多进程加载数据会出现重复读取相同数据情况。
e.g.:

# should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
# Single-process loading
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
# [3, 4, 5, 6]
# Directly doing multi-process loading yields duplicate data
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
# [3, 3, 4, 4, 5, 5, 6, 6]

2.1.3 示例

  • 加载CIFAR10数据集
import os
import pickle
import numpy as np

def load_CIFAR_batch(filename):
    '''load single batch of cifar'''
    with open(filename, 'rb') as f:
        datadict = pickle.load(f, encoding='latin1')
        x = datadict['data']
        print(x.shape)
        y = datadict['labels']
        X = x.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
        Y=np.array(y)
        return X, Y

def load_CIFAR10(path):
    '''load all of cifar'''
    xs = []
    ys = []
    for b in range(1,6):
        f = os.path.join(path,'data_batch_%d' % (b, ))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(path, 'test_batch'))
    return Xtr, Ytr, Xte, Yte

Xtr, Ytr, Xte, Yte = load_CIFAR10('D:\\数据集\\cifar-10\\')
  • Map-Style和Iterable-Style
import torch
from torch.utils.data import Dataset, IterableDataset, DataLoader

# Xtr, Xte展开
Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3)
Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3)

# 创建子类
class MyMapstyleDataset(Dataset):
    def __init__(self, datas, labels):
        # super(MyMapstyleDataset).__init__()
        self.datas = datas
        self.labels = labels
        
    def __getitem__(self, index):
        data = torch.tensor(self.datas[index])
        label = torch.tensor(self.labels[index])
        return data, label
    
    def __len__(self):
        return len(self.datas)
    
Dataset = MyMapstyleDataset(Xtr_rows, Ytr)
train_dataloader = DataLoader(IterableDataset, shuffle=False, batch_size=8, num_workers=0)


# 创建子类
class MyIterstyleDataset(IterableDataset):
    def __init__(self, start, end):
        super(MyIterstyleDataset).__init__()
        assert end > start, "Error"
        self.start = start
        self.end = end
        # self.filepath = filepath
        
    def _sample_gernerator(self, start, end):
        # 当数据量大,无法一次Load进内存,可以通过这个函数以数据流的形式加载
        for i in range(end-start):
            sample = {"data":torch.tensor(Xtr[start+i, :]), "label":torch.tensor(Ytr[start+i, :])}
            yield sample
        
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:# Single Worker
            iter_start = self.start
            iter_end = self.end
        else:# Multiple Workers
            per_work = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        sample_iterator = self._sample_gernerator(iter_start, iter_end)
        
        return sample_iterator
    
    def __len__(self):
        return self.end - self.start
    

IterableDataset = MyIterstyleDataset(0, len(Xtr_rows))
Iter_train_dataloader = DataLoader(IterableDataset, shuffle=False, batch_size=8, num_workers=0)

2.2 Sampler

Sampler 提供多种数据读取方式。对于 Iterable-Style 数据集,数据读取是自定义的。因此 Sampler 更多的应用在 MapStyle 数据集中。torch.utils.data.Sampler 类生成特定的索引/键值序列用于数据读取,是数据集索引/键值的 Iterable object。

class Sampler(Gerneric[T_co]):
	def __init__(self, data_source) -> None:
		pass
	
	def __iter__(self) -> Iterator[T_co]:
		raise NotImplementedError 

class SequentialSampler(Sampler[int])
'''顺序采样,顺序始终相同.'''

class RandomSampler(Sampler[int])
'''随机采样,"replacement"为True时,可能修改dataset大小,具体看"num_samples"'''

class SubsetRandomSampler(Sampler[int]):
'''按照给定的索引/键值序列采样'''

class WeightedRandomSampler(Sampler[int]):
'''对[0,...,len(weights)-1]的样本,按照weights的概率,随机抽取num_samples个样本'''

简单的 Sequential Sampler 和 shuffled Sampler 使用 DataLoader 的 shuffle 参数即可:True 为 shuffled;False 为 sequential。
同样用户可以自定义 sampler,提供一种__iter()__方法,每次迭代生成下一个索引/键值。

2.3 Automatic batching(batch_sampler)

  • batchsize:每个 batch 包含多少个样本(int,默认:1)
  • drop_last:是否舍弃最后一个不完整的 batch(bool,默认 false)
  • batch_sampler:类似于 sampler,每次生成一个 batch 的索引/键值。
class BatchSampler(Sampler[list[int]]):
	def __init__(self, sampler, batch_size, drop_last):
		self.sampler = sampler
		self.batch_size = batch_size
		self.drop_last = drop_last
	
	def __iter__(self) -> Iterator[List[int]]:
		batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch
	
	def __len__(self) -> int:
		if self.drop_last:
            return len(self.sampler)
        else:
            return (len(self.sampler) + self.batch_size - 1) 

DataLoader 支持将单个获取的数据自动整理为 batch ,其中 batchsize 和 drop_last 参数用于从 sampler 构成 batch_sampler 。
Note: 使用多进程从 Iterable-Style 数据集中读取数据时, drop_last 参数会丢弃每个 worker 的数据集副本的最后一个不完整 batch。

'''读取Map-Style数据集(Automatic batching Enable)'''
for incices in batch_sampler:
	yield collate_fn([dataset[i] for i in indices])

'''读取Iterable-Style数据集(Automatic batching Enable)'''
dataset_iter = iter(dataset)
for indices in batch_sampler:
	yield collate_fn([next(dataset_iter) for _ in indices])

当 batch_size(默认为1,1不是None)和 batch_sampler(默认)都是None时,Automatic batching禁用。

'''读取Map-Style数据集(Automatic batching Disnable)'''
for index in sampler:
	yield collate_fn(dataset[index])

'''读取Iterable-Style数据集(Automatic batching Disnable)'''
for data in iter(dataset):
	yield collate_fn(data)

2.4 collate_fn

当sampler或batch_sampler获取数据样本后,使用collate_fn 函数将样本列表整理成batch。

当 Automatic batching 禁用时: collate_fn 被每个单独的数据样本调用。在这种情况下,默认的 collate_fn 只是将NumPy 数组转换为PyTorch张量。
当 Automatic batching 禁用时: collate_fn 被每个数据样本列表调用。将读取的样本列表整理成batch。
e.g. Dataset 中每个样本是(image, class_index)的元组,collate_fn 返回image tensor和 label tensor。

关于默认的 collate_fn :

  • 增加一个维度表示 batchsize;
  • 将 nparray 转化为 tensor
  • 保留原有的数据结构,比如字典、元组。

用户可以自定义 collate_fn 实现自定义 batching,比如,填充各种长的序列。

3. 参考

[1]. torch.utils.data.DataLoader官方文档
[2]. Pytorch IterableDataset的使用
[3]. 加载CIFAR_10

你可能感兴趣的:(PyTorch学习,pytorch,深度学习,机器学习)