Pytorch的DataLoader, DataSet, Sampler之间的关系

转载自作者marsggbo
https://www.cnblogs.com/marsggbo/p/11308889.html

在 pytorch 的体系中,数据加载的最终目的使用 Dataloader 处理 dataset 对象,以方便的控制 Batch,Shuffle 等等操作。

Pytorch的DataLoader, DataSet, Sampler之间的关系_第1张图片

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None)

初始化参数里有两种sampler:sampler和batch_sampler,都默认为None。前者的作用是生成一系列的index,而batch_sampler则是将sampler生成的indices打包分组,得到一个又一个batch的index

num_worker 负责数据加载的多进程数量。
num_worker设置得大,好处是寻batch速度快,因为下一轮迭代的batch很可能在上一轮/上上一轮…迭代时已经加载好了。坏处是内存开销大,也加重了CPU负担(worker加载数据到RAM的进程是CPU复制的嘛)。

如果num_worker设为0,意味着每一轮迭代时,dataloader不再有自主加载数据到RAM这一步骤(因为没有worker了),而是先在RAM中找batch,找不到时再加载相应的batch。缺点当然是速度更慢。

collate_fn 如何取样本的,我们可以定义自己的函数来准确地实现想要的功能

drop_last 告诉如何处理数据集长度除于batch_size余下的数据。True就抛弃,否则保留

pin_memory pin_memory=True,那么Pytorch会采取一系列操作把数据拷贝到GPU,总之就是为了加速。

worker_init_fn 子进程导入模式,默认为Noun。在数据导入前和步长结束后,根据工作子进程的ID逐个按顺序导入数据

你可能感兴趣的:(学习记录)