PyTorch data loading: a summary of Dataset, Sampler and DataLoader

Official tutorial example:

https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

Loading data in PyTorch takes three steps:


  1. Dataset: parses individual samples, mapping the data into (x, y) form;

    • map-style: implements the __getitem__ and __len__ interfaces; random access is cheap (used in most cases);
    • iterable-style: implements the __iter__ interface; random access is expensive, better suited to streaming data (e.g. text streams);
  2. Sampler: provides a way to iterate over the indices of all elements in the dataset; sensible defaults are provided;

  3. DataLoader: turns individual samples into the batched form needed for training (a minimal end-to-end sketch follows below).
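
The sketch below wires the three pieces together so the steps are concrete; the toy dataset, its shapes and the batch size are made up purely for illustration.

import torch
from torch.utils.data import Dataset, DataLoader

# Step 1: a map-style Dataset mapping an index to one (x, y) pair
class ToyDataset(Dataset):
    def __init__(self, n=100):
        self.x = torch.randn(n, 3)
        self.y = torch.randint(0, 10, (n,))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

# Steps 2 + 3: shuffle=True makes the DataLoader use the default RandomSampler,
# and the DataLoader collates individual samples into batches
loader = DataLoader(ToyDataset(), batch_size=16, shuffle=True)

x_batch, y_batch = next(iter(loader))
print(x_batch.shape, y_batch.shape)   # torch.Size([16, 3]) torch.Size([16])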


1 Dataset

1.1 Source code

# The interface
from torch.utils.data import Dataset

# Source location
# ../torch/utils/data/dataset.py

# Check where torch is installed
import torch
print(torch.__file__)

The source:

# The Dataset abstract class exposes a few interfaces

# map-style
class Dataset(Generic[T_co]):

    def __getitem__(self, index) -> T_co:
        # Not implemented in the base class; subclasses must implement it
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])
# iterable-style
class IterableDataset(Dataset[T_co]):

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])

1.2 Creating your own Dataset

To define your own Dataset, subclass Dataset and implement (at minimum) three methods:

  • __init__
  • __len__
  • __getitem__

Example:

import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        # Root directory of the images
        self.img_dir = img_dir
        # Transform applied to the images (augmentation and the like)
        self.transform = transform
        # Transform applied to the labels
        self.target_transform = target_transform

    def __len__(self):
        # Return the total number of samples
        return len(self.img_labels)

    def __getitem__(self, idx):
        # Build the full path to the image
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        # Read the image
        image = read_image(img_path)
        # Extract the label from the csv row
        label = self.img_labels.iloc[idx, 1]
        # Apply the transforms to the image and label
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
    
# The csv stores image names and labels; roughly like this:
# tshirt1.jpg, 0
# tshirt2.jpg, 0
# ......
# ankleboot999.jpg, 9

1.3 Loading the dataset

ann_csv = '../ann.csv'
img_root = '/'

# Instantiate the dataset
myDataset = CustomImageDataset(ann_csv, img_root)

# Access the instance's attributes
print(myDataset.img_dir)

# Get the number of samples (works, but usually not called this way)
print(myDataset.__len__())

# Get the img and label of the first sample (index 0)
# (works, but this is not the usual way either)
img,lab = myDataset.__getitem__(0)
print(img.shape, lab)

# The usual way...
print(len(myDataset))
img,lab = myDataset[0]
print(img.shape, lab)


# A Dataset is rarely used on its own:
# it is handed to a DataLoader to build batched data

1.4 Subclasses of Dataset

1.4.1 TensorDataset

If the data is already in tensor form:

import torch
from torch.utils.data import TensorDataset

# Convert the data to tensors
x_train, y_train = torch.tensor(x_train), torch.tensor(y_train)

# Wrap them directly with TensorDataset
train_dataset = TensorDataset(x_train, y_train)
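
Each element of a TensorDataset is simply the tuple (x_train[i], y_train[i]), so it can be indexed like any map-style dataset or handed straight to a DataLoader. A short sketch with dummy tensors (the shapes are made up):

import torch
from torch.utils.data import TensorDataset, DataLoader

x_train = torch.randn(8, 4)   # dummy features
y_train = torch.arange(8)     # dummy labels

train_dataset = TensorDataset(x_train, y_train)
print(train_dataset[0])       # (x_train[0], y_train[0])

for xb, yb in DataLoader(train_dataset, batch_size=4):
    print(xb.shape, yb.shape)  # torch.Size([4, 4]) torch.Size([4])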

1.4.2 IterableDataset

Generates a dataset from two numbers, start and end;

# Subclass IterableDataset (imports added so the snippet runs on its own)
import math
import torch

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:  # in a worker process
            # split workload
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        # ! The core: the data is just the numbers produced by range
        return iter(range(iter_start, iter_end))
# Instantiate
# produces the numbers 3, 4, 5, 6
ds = MyIterableDataset(start=3, end=7)

# Load with a DataLoader, single process
# [tensor([3]), tensor([4]), tensor([5]), tensor([6])]
print(list(torch.utils.data.DataLoader(ds, num_workers=0)))

# Load with a DataLoader, two worker processes
# each worker gets half of the range, so the two halves appear interleaved in the output
print(list(torch.utils.data.DataLoader(ds, num_workers=2)))

1.4.3 ConcatDataset

Concatenates several datasets into one;
usage:

from torchvision.datasets import MNIST, CIFAR100
from torch.utils.data import ConcatDataset

# First dataset, len 60000
mnist_data = MNIST('./data', train=True, download=True)

# Second dataset, len 50000
cifar100_data = CIFAR100('./data', train=True, download=True)

# The two datasets concatenated, len 110000
concat_data = ConcatDataset([mnist_data, cifar100_data])

1.4.4 ChainDataset

Chains several IterableDataset-style datasets into a single dataset; a short sketch follows.
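
A minimal sketch, reusing the MyIterableDataset class from section 1.4.2 (the ranges are arbitrary):

from torch.utils.data import ChainDataset

# Two iterable-style datasets chained one after the other
ds_a = MyIterableDataset(start=3, end=7)    # yields 3, 4, 5, 6
ds_b = MyIterableDataset(start=10, end=13)  # yields 10, 11, 12
chained = ChainDataset([ds_a, ds_b])

# Chaining is lazy; iterating simply walks the datasets in order
print(list(chained))  # [3, 4, 5, 6, 10, 11, 12]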

1.4.5 Subset

Splits a dataset into sub-datasets, e.g. a training set and a validation set. The snippet below takes the common route of building SubsetRandomSamplers over fixed index lists; a sketch using Subset itself follows it.

# Indices for the training and validation sets
train_indices, val_indices = indices[split:], indices[:split]

# Samplers that draw randomly from the given index lists
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
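
For Subset itself, a minimal sketch (the 80/20 split is arbitrary, and `dataset` stands for any map-style dataset):

import torch
from torch.utils.data import Subset

# Shuffled indices, then an 80/20 split
indices = torch.randperm(len(dataset)).tolist()
split = int(0.8 * len(dataset))

train_set = Subset(dataset, indices[:split])
val_set = Subset(dataset, indices[split:])

# A Subset behaves like a normal Dataset: len() and indexing both work
print(len(train_set), len(val_set))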

2 Sampler

A Sampler defines the order in which the dataset's indices are visited; there are two defaults:

  • shuffle = True: sampler = RandomSampler(dataset, generator=generator), indices are shuffled;
  • shuffle = False: sampler = SequentialSampler(dataset), indices are kept in order;
  • you can also pass in a custom Sampler, but sampler and shuffle are mutually exclusive (see the sketch below);
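
A small sketch of the equivalence (`dataset` stands for any map-style dataset):

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

# These two loaders draw indices the same way ...
loader_a = DataLoader(dataset, batch_size=32, shuffle=True)
loader_b = DataLoader(dataset, batch_size=32, sampler=RandomSampler(dataset))

# ... and so do these two
loader_c = DataLoader(dataset, batch_size=32, shuffle=False)
loader_d = DataLoader(dataset, batch_size=32, sampler=SequentialSampler(dataset))

# Passing both at once is rejected:
# DataLoader(dataset, shuffle=True, sampler=RandomSampler(dataset))  # raises ValueError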

2.1 RandomSampler

class RandomSampler(Sampler[int]):
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
                 
        # self.<attr> = ...  (attribute assignments omitted)


    def __iter__(self) -> Iterator[int]:
        n = len(self.data_source)

        if self.replacement:
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            for _ in range(self.num_samples // n):
                # The core is torch.randperm, which produces a random
                # permutation of the indices 0..n-1
                # (in the full source, `generator` is derived from self.generator above)
                yield from torch.randperm(n, generator=generator).tolist()
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples

2.2 SequentialSampler

SequentialSampler does nothing fancy: it yields the indices in order, preserving the dataset's original ordering;

class SequentialSampler(Sampler[int]):

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source

    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

2.3 Custom Sampler

import random
from torch.utils.data.sampler import Sampler
 
# A custom sampler must subclass the Sampler class
# and implement the __init__, __iter__ and __len__ methods
class MySampler(Sampler):
    def __init__(self, dataset):
        # Split the dataset into two halves
        halfway_point = int(len(dataset)/2)
        self.first_half_indices = list(range(halfway_point))
        self.second_half_indices = list(range(halfway_point, len(dataset)))
        
    def __iter__(self):
        # Shuffle each half independently, then yield all of the first half's
        # indices followed by all of the second half's,
        # e.g. first half  0 1 2 3 4 (shuffled)
        #      second half 5 6 7 8 9 (shuffled)
        random.shuffle(self.first_half_indices)
        random.shuffle(self.second_half_indices)
        return iter(self.first_half_indices + self.second_half_indices)

    def __len__(self):
        return len(self.first_half_indices) + len(self.second_half_indices)
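
To use it, pass the sampler to a DataLoader (remember that sampler and shuffle are mutually exclusive); `dataset` stands for any map-style dataset:

from torch.utils.data import DataLoader

# The DataLoader asks MySampler for indices and fetches those samples from dataset
loader = DataLoader(dataset, batch_size=8, sampler=MySampler(dataset))

for batch in loader:
    ...  # all first-half samples arrive before any second-half sample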

3 DataLoader

3.1 Using DataLoader

from torch.utils.data import DataLoader

training_data = CustomImageDataset(...)
test_data = CustomImageDataset(...)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
# The test set is usually not shuffled; there is nothing to gain from it
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)

train_features, train_labels = next(iter(train_dataloader))
# Feature batch shape: torch.Size([64, 1, 28, 28])
print(f"Feature batch shape: {train_features.size()}")
# Labels batch shape: torch.Size([64])
print(f"Labels batch shape: {train_labels.size()}")

# ... further processing

3.2 Source code and parameters

# DataLoader source location
# /torch/utils/data/dataloader.py

# Parameters
# dataset: a Dataset instance
# batch_size: batch size, default 1
# shuffle: whether to reshuffle the data every epoch
# sampler: how the dataset's indices are traversed; has a default; mutually exclusive with shuffle
# batch_sampler: like sampler, but yields a batch of indices at a time;
#   mutually exclusive with shuffle, sampler, drop_last and batch_size
# num_workers: default 0; number of worker processes used to load batches
#   (a common rule of thumb is the number of CPU cores on the machine;
#    0 means everything is loaded in the main process)
# collate_fn: merges a list of samples into a batch and can post-process it
# pin_memory: copy batches into page-locked memory, which speeds up transfers to the GPU
# drop_last: drop the last batch if it is smaller than batch_size
# timeout: if positive, how long to wait when collecting a batch from a worker;
#   if that time is exceeded, an error is raised
class DataLoader(Generic[T_co]):
    dataset: Dataset[T_co]
    batch_size: Optional[int]
    num_workers: int
    pin_memory: bool
    drop_last: bool
    timeout: float
    sampler: Union[Sampler, Iterable]
    prefetch_factor: int
    _iterator : Optional['_BaseDataLoaderIter']
    __initialized = False
   

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Union[Sampler, Iterable, None] = None,
                 batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
                 num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
 		
        # A series of member-variable assignments
        # self.<attr> = ...
	
        # Setting up the sampler (still inside __init__)
        if sampler is None:
            if self._dataset_kind == _DatasetKind.Iterable:
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    # under the hood: torch.randperm, shuffled order
                    sampler = RandomSampler(dataset, generator=generator)
                else:
                    # under the hood: iter(range(len)), sequential order
                    sampler = SequentialSampler(dataset)
                   
    # Called from __iter__ below;
    # builds the iterator object that iter(train_dataloader) returns
    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            # fetches the next index, then loads and returns the data for it
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            # multi-process loading
            return _MultiProcessingDataLoaderIter(self)

        
  
    # Makes the DataLoader iterable
    def __iter__(self) -> '_BaseDataLoaderIter':
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()


	
    # Used by the _BaseDataLoaderIter classes, i.e. when next(iter(train_dataloader)) runs;
    # picks which sampler actually drives iteration:
    # batch_sampler when auto-collation is on, the plain sampler otherwise
    @property
    def _index_sampler(self):
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler
        

    # Returns the number of batches
    def __len__(self) -> int:
        # ...

    # Sanity-checks the num_workers setting
    def check_worker_number_rationality(self):
        # ...
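
Of these parameters, collate_fn is the one that most often needs customizing. A minimal hand-rolled sketch (the padding logic and the name pad_collate are made up for illustration; it assumes each sample is a (1-D tensor, int label) pair):

import torch
from torch.utils.data import DataLoader

def pad_collate(batch):
    # batch is a list of (x, y) samples taken from the dataset
    xs, ys = zip(*batch)
    max_len = max(x.size(0) for x in xs)
    # Zero-pad every sequence to the longest one in this batch
    padded = torch.zeros(len(xs), max_len)
    for i, x in enumerate(xs):
        padded[i, :x.size(0)] = x
    return padded, torch.tensor(ys)

# loader = DataLoader(dataset, batch_size=32, collate_fn=pad_collate)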
