torch的DataLoader 浅析

torch的DataLoader主要是用来装载数据,就是给定已知的数据集,把数据集装载进DataLoaer,然后送入深度学习网络进行训练。先看一下它的声明吧。(官方声明,pytorch 1.10.0文档,见参考资料1)

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

其中的参数如下:

  • dataset (Dataset) – dataset from which to load the data.(数据类型:DataSet,需要装载进DataLoader的原始数据集)

  • batch_size (intoptional) – how many samples per batch to load (default: 1).(数据类型:int,可选项,批的大小,默认为1)

  • shuffle (booloptional) – set to True to have the data reshuffled at every epoch (default: False).(数据类型:bool,可选项,每个循环是否需要重新打乱或洗牌)

  • sampler (Sampler or Iterableoptional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.(可选项,定义了从数据集获取样本的策略,可以是任何实现了__len__的迭代器类型(Iterable),如果使用了这个选项,shuffle不可再设置

  • batch_sampler (Sampler or Iterableoptional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_sizeshufflesampler, and drop_last.(数据类型Sampler 或者Iterable,可选项,就像sampler一样,但是它一次返回一批的索引,使用后,不可使用batch_sizeshufflesampler, drop_last选项

  • num_workers (intoptional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)(数据类型:int,可选项,确定使用多少个子进程来进行数据加载,0代表使用主进程加载。默认为0)

  • collate_fn (callableoptional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.(数据类型:可调用的类型,可选项,合并一个链表的样本来形成最小批的张量,当从映射类型数据集装载的时候使用)

  • pin_memory (booloptional) – If True, the data loader will copy Tensors into CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.(数据类型:bool,可选项,如果设置为True,dataloader 会在返回张量前将其拷贝至CUDA的pinned的内存区。如果你的数据类型是一个个性化的类型或者你的collate_fn返回了个性化的一批样本,请参看下面例子

  • drop_last (booloptional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)(数据类型:bool,可选项,若设置为True,如果数据集不能被批除尽,则舍弃最后一个不满足整批的样本,若设置为False,如果数据集不能被批除尽,则保留最后一个不满足整批的样本。)

  • timeout (numericoptional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)(数据类型:数字,可选项,如果是正值,此选项代表从一个进程收集批数据的超时时间,应该是非负值)

  • worker_init_fn (callableoptional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)(数据类型:可调用的类型,可选项,在进行随机种子之后,且装载数据之前,这个可调用类型在调用之后,它的输出作为每一个子进程的输入,子进程的ID为[0, num_workers - 1]

  • generator (torch.Generatoroptional) – If not None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None)(数据类型:torch.Generator,可选项,如果此选项不是None,RandomSampler 将使用RNG去生成随机的索引,在多进程中也会生成每个进程基准种子)

  • prefetch_factor (intoptionalkeyword-only arg) – Number of samples loaded in advance by each worker. 2 means there will be a total of 2 * num_workers samples prefetched across all workers. (default: 2)(数据类型:int,可选项,只是用关键字的参数,对每个进程来说,需要提前装载的样本数量。当此值为2时,对所有的进程来数,需要提前获取的样本数量为2*num_workers,默认值为2)

  • persistent_workers (booloptional) – If True, the data loader will not shutdown the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)(数据类型bool,可选项,如果设置为True,对每个进程来数,当一个数据集被使用一次后,此进程并不会被关闭,这样就会保持进程中的数据集实例是活的)

具体来看dataset的类型,

Dataset Types

The most important argument of DataLoader constructor is dataset, which indicates a dataset object to load data from. PyTorch supports two different types of datasets:

  • map-style datasets,

  • iterable-style datasets.

Map-style datasets

A map-style dataset is one that implements the __getitem__() and __len__() protocols, and represents a map from (possibly non-integral) indices/keys to data samples.

For example, such a dataset, when accessed with dataset[idx], could read the idx-th image and its corresponding label from a folder on the disk.

See Dataset for more details.

Iterable-style datasets

An iterable-style dataset is an instance of a subclass of IterableDataset that implements the __iter__() protocol, and represents an iterable over data samples. This type of datasets is particularly suitable for cases where random reads are expensive or even improbable, and where the batch size depends on the fetched data.

For example, such a dataset, when called iter(dataset), could return a stream of data reading from a database, a remote server, or even logs generated in real time.

See IterableDataset for more details.

NOTE

When using an IterableDataset with multi-process data loading. The same dataset object is replicated on each worker process, and thus the replicas must be configured differently to avoid duplicated data. See IterableDataset documentations for how to achieve this.

DataLoader中最重要的参数就是dataset,它决定了要装载的数据集。torch支持两种类型的数据集

(1)map-style 类型。

一个map-style类型是实现了__getitem__() 和__len__()协议的类,它代表了一个从索引/键值 到数据样本的映射。

例如,对于一个通过dataset[idx]访问的数据集,可以读到第idx个图片,并从磁盘的文件中取到对应的标签。

看第一个例子吧(见参考资料4)

import torch
from torch.utils.data import DataLoader
import numpy as np

class MyLoader(torch.utils.data.Dataset):

#父类是torch.utils.data.Dataset,也可以是object,对父类没有要求
    def __init__(self,data,label):
        self.data=data
        self.label=label
    def __getitem__(self,index):#迭代数据
        data=self.data[index]
        labels=self.label[index]
        return data,labels
    def __len__(self):#返回数据的总长度
        return len(self.data)
source_data=np.random.rand(10,20)
source_label=np.random.randint(0,2,(10,1))

torch_data=MyLoader(source_data,source_label)

for i,data in enumerate(torch_data):
    print('第{}个 Batch {}'.format(i,data))

针对图像分割,数据集是飞浆平台课程上的(参考资料3),数据集原图片,而标签是和原数据集大小一样的图片,如下图所示:

原图片:

torch的DataLoader 浅析_第1张图片

而标签为

torch的DataLoader 浅析_第2张图片

原图片和标签都是经过转换之后的。而list文件如下,共14个样本:

humanseg/aa645bc9cf23db7912a69309072cd9ab325f02cd.jpg visual/aa645bc9cf23db7912a69309072cd9ab325f02cd.png
humanseg/aa63d7e6db0d03137883772c246c6761fc201059.jpg visual/aa63d7e6db0d03137883772c246c6761fc201059.png
humanseg/aa6300f76981dcf8701534dd1d3b2ec19b3dee02.jpg visual/aa6300f76981dcf8701534dd1d3b2ec19b3dee02.png
humanseg/56173ddd1ccb419e1efdeb5f5cb242ab160142cb.jpg visual/56173ddd1ccb419e1efdeb5f5cb242ab160142cb.png
humanseg/aa6bd3eaf471bea1cca7467a95fe93e69b006797.jpg visual/aa6bd3eaf471bea1cca7467a95fe93e69b006797.png
humanseg/aa67b2d074e00942191c4bd2472e7f77538ec113.jpg visual/aa67b2d074e00942191c4bd2472e7f77538ec113.png
humanseg/aa6ff076c7360b8dabc30edd05ebafb65bba9343.jpg visual/aa6ff076c7360b8dabc30edd05ebafb65bba9343.png
humanseg/aa611a0cf92ace38bd2d3b0fe0bc50b5235eea7e.jpg visual/aa611a0cf92ace38bd2d3b0fe0bc50b5235eea7e.png
humanseg/aa65f5b4f85c37ce44dc48473150a16e652b6bc5.jpg visual/aa65f5b4f85c37ce44dc48473150a16e652b6bc5.png
humanseg/aa65c231dbce73de1527101bf35b975b2c2e9d5a.jpg visual/aa65c231dbce73de1527101bf35b975b2c2e9d5a.png
humanseg/aa6b34b24414bafa7fab8393239c793587513ce6.jpg visual/aa6b34b24414bafa7fab8393239c793587513ce6.png
humanseg/aa662fb7540312c51f6e6870c0542c8035495b14.jpg visual/aa662fb7540312c51f6e6870c0542c8035495b14.png
humanseg/aa6f23e6ac596962ee773e4eea0560fb0e4522ac.jpg visual/aa6f23e6ac596962ee773e4eea0560fb0e4522ac.png
humanseg/aa65dc40ae9713e4fe3e63b55a8fd10bd1320822.jpg visual/aa65dc40ae9713e4fe3e63b55a8fd10bd1320822.png

每一行中的第一个是原始文件,第二个是标签文件 ,开发平台linux平台,python 版本3.7.4,anaconda3,torch版本1.10.0+cpu

class Transform(object): #图片转换
    def __init__(self,size=256):
        self.size=size
    def __call__(self,input,label):
        input=cv2.resize(input,(self.size,self.size),interpolation=cv2.INTER_LINEAR)
        label=cv2.resize(label,(self.size,self.size),interpolation=cv2.INTER_NEAREST)

        return input,label
#map-style datasets
class MapDataLoader(torch.utils.data.Dataset):
    def __init__(self,image_folder,image_list_file,transform=True,shuffle=True):
        self.image_folder=image_folder
        self.image_list_file=image_list_file
        self.transform=transform
        self.shuffle=shuffle
        
        self.data_list=self.read_list() #读取列表
        self.data_total=self.get_total() #获取所有的数据集,包括原数据和标签,放入列表
        
    def __getitem__(self,index):
        data=self.data_total[index][0]
        labels=self.data_total[index][1]
        return data,labels
    def read_list(self):
        data_list=[]
        with open(os.path.join(self.image_folder,self.image_list_file)) as infile:
            for line in infile:
                data_path=os.path.join(self.image_folder,line.split()[0])
                label_path=os.path.join(self.image_folder,line.split()[1])
                data_list.append((data_path,label_path))
        random.shuffle(data_list)
        return data_list
    def get_total(self):
        total_list=[]
        for data_path,label_path in self.data_list:
            data=cv2.imread(data_path,cv2.IMREAD_COLOR)
            label=cv2.imread(label_path,cv2.IMREAD_GRAYSCALE)
            assert data.all!=None,"NoneType"
            print(data.shape,label.shape)
            data,label= self.preprocess(data,label)
            print('after:',data.shape,label.shape)
            total_list.append((data,label))
        random.shuffle(total_list)
        return total_list
    def preprocess(self,data,label):
        h,w,c=data.shape
        h_gt,w_gt=label.shape
        assert h==h_gt,"Error"
        assert w==w_gt,"Error"
        if self.transform:
            data,label=self.transform(data,label)
        
        label=label[:,:,np.newaxis] #扩展一维

        return data,label
        

    def __len__(self):
        return len(self.data_total)
transform=Transform(256)
map_dataloader=MapDataLoader(
            image_folder='../data',
            image_list_file='list_linux.txt',
            transform=transform,
            shuffle=True
        )

输出结果如下:

(1000, 706, 3) (1000, 706)
after: (256, 256, 3) (256, 256, 1)
(664, 1000, 3) (664, 1000)
after: (256, 256, 3) (256, 256, 1)
(768, 484, 3) (768, 484)
after: (256, 256, 3) (256, 256, 1)
(1000, 666, 3) (1000, 666)
after: (256, 256, 3) (256, 256, 1)
(960, 717, 3) (960, 717)
after: (256, 256, 3) (256, 256, 1)
(940, 626, 3) (940, 626)
after: (256, 256, 3) (256, 256, 1)
(600, 900, 3) (600, 900)
after: (256, 256, 3) (256, 256, 1)
(565, 800, 3) (565, 800)
after: (256, 256, 3) (256, 256, 1)
(633, 940, 3) (633, 940)
after: (256, 256, 3) (256, 256, 1)
(825, 550, 3) (825, 550)
after: (256, 256, 3) (256, 256, 1)
(634, 950, 3) (634, 950)
after: (256, 256, 3) (256, 256, 1)
(939, 626, 3) (939, 626)
after: (256, 256, 3) (256, 256, 1)
(1000, 737, 3) (1000, 737)
after: (256, 256, 3) (256, 256, 1)
(676, 1000, 3) (676, 1000)
after: (256, 256, 3) (256, 256, 1)
datas=DataLoader(map_dataloader,batch_size=4,shuffle=True,drop_last=False,num_workers=4)#windows下num_workers 需要设置为0
num_epoch=2
for epoch in range(1,num_epoch+1):
    print(f'Epoch [{epoch}/{num_epoch}]')
    for index,(data,label) in enumerate(datas):
        print(f'Iter {index},data shape:{data.shape} Label shape:{label.shape}')

输出结果:

Epoch [1/2]
Iter 0,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 1,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 2,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 3,data shape:torch.Size([2, 256, 256, 3]) Label shape:torch.Size([2, 256, 256, 1])
Epoch [2/2]
Iter 0,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 1,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 2,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 3,data shape:torch.Size([2, 256, 256, 3]) Label shape:torch.Size([2, 256, 256, 1])

(2)Iterable-style 类型。

这种类型的父类是IterableDataset,并且实现了 __iter__() 协议,代表了在数据样本的迭代器。这种类型非常适合随机读取非常难或者不可能的情况,这种情况下批的大小取决于得到的数据。

例如有这样一个数据集,可以调用iter(dataset),可以从数据库、远程服务器或者实时产生的日志获取样本流。

注意。

当使用多进程加载Iterable-style 类型的DataLoader时,一份样本会被复制至所有的进程中,这份复制的样本将被差异配置以避免重复数据。请参阅IIterableDataset文档以如何实现它。

还是以图像分割为例。

#iter_style
class IterDataLoader(torch.utils.data.IterableDataset):#父类是torch.utils.data.IterableDataset
    def __init__(self,image_folder,image_list_file,transform=True,shuffle=True):
        self.image_folder=image_folder
        self.image_list_file=image_list_file
        self.transform=transform
        self.shuffle=shuffle
        self.data_list=self.read_list()
        self.data_total=self.get_total()
        self.start=0
        self.end=len(self.data_total)
    def read_list(self):
        data_list=[]
        with open(os.path.join(self.image_folder,self.image_list_file)) as infile:
            for line in infile:
                data_path=os.path.join(self.image_folder,line.split()[0])
                label_path=os.path.join(self.image_folder,line.split()[1])
                data_list.append((data_path,label_path))
        random.shuffle(data_list)
        return data_list
    def get_total(self):
        total_list=[]
        for data_path,label_path in self.data_list:
            data=cv2.imread(data_path,cv2.IMREAD_COLOR)
#             data=cv2.cvtColor(data,cv2.COLOR_BAYER_BG2RGB)
            label=cv2.imread(label_path,cv2.IMREAD_GRAYSCALE)
            assert data.all!=None,"NoneType"
            print(data.shape,label.shape)
            data,label= self.preprocess(data,label)
            print('after:',data.shape,label.shape)
            total_list.append((data,label))
        random.shuffle(total_list)
        return total_list
    def preprocess(self,data,label):
        h,w,c=data.shape
        h_gt,w_gt=label.shape
        assert h==h_gt,"Error"
        assert w==w_gt,"Error"
        if self.transform:
            data,label=self.transform(data,label)
        
        label=label[:,:,np.newaxis]

        return data,label
        

    def __len__(self):
        return len(self.data_list)
    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            return iter(self.data_total) #单进程情况下返回所有的
        else:  # 多进程情况下
            per_worker = int(math.ceil(len(self.data_total)/ float(worker_info.num_workers))) #计算出每个进程需要装载样本的数量
#             print('per_worker:',per_worker)
            worker_id = worker_info.id
#             print('worker_id:{}\n'.format(worker_id))
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
#             print('start{}:end{}\n'.format(iter_start,iter_end))
            return iter(self.data_total[iter_start:iter_end])

#torch 读取图片
transform=Transform(256)
iter_dataloader=IterDataLoader(
            image_folder='../data',
            image_list_file='list_linux.txt',
            transform=transform,
            shuffle=True
        )
datas=DataLoader(iter_dataloader,batch_size=4,drop_last=False,num_workers=2)#window下num_workers需要设置为0,且不可以使用shuffle==True
num_epoch=2
for epoch in range(1,num_epoch+1):
    print(f'Epoch [{epoch}/{num_epoch}]')
    for index,(data,label) in enumerate(datas):
        print(f'Iter {index},data shape:{data.shape} Label shape:{label.shape}')

输出结果如下:

Epoch [1/2]
Iter 0,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 1,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 2,data shape:torch.Size([3, 256, 256, 3]) Label shape:torch.Size([3, 256, 256, 1])
Iter 3,data shape:torch.Size([3, 256, 256, 3]) Label shape:torch.Size([3, 256, 256, 1])
Epoch [2/2]
Iter 0,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 1,data shape:torch.Size([4, 256, 256, 3]) Label shape:torch.Size([4, 256, 256, 1])
Iter 2,data shape:torch.Size([3, 256, 256, 3]) Label shape:torch.Size([3, 256, 256, 1])
Iter 3,data shape:torch.Size([3, 256, 256, 3]) Label shape:torch.Size([3, 256, 256, 1])

从输出结果来看,静态的数据还是用map-style类型比较合适。

参考资料:

1 torch.utils.data — PyTorch 1.10.1 documentation

2 pytorch/dataloader.py at master · pytorch/pytorch · GitHub

3飞桨PaddlePaddle-源于产业实践的开源深度学习平台

4Pytorch加载自己的数据集(使用DataLoader加载Dataset)_北国觅梦-CSDN博客

你可能感兴趣的:(python,深度学习,pytorch,深度学习,神经网络)