Python:PyTorch 数据加载 torch.utils.data.DataLoader()

使用基于 PyTorch 构建的模型进行训练前,需要对数据进行加载操作

即使用 torch.utils.data.DataLoader()

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, *, prefetch_factor=2, persistent_workers=False)

对数据 dataset 进行加载,同时提供可分批加载功能,即设置 batch_size

dataset:需要加载的数据集

batch_size:默认1,batch的大小

shuffle:默认 False,是否在每个epoch重新打乱样本顺序

sampler:默认None,定义从数据集中获取样本的策略,设定此项则忽略 shuffle

num_workers:默认0,加载使用的进程的数量,0表示在主进程中加载数据

collate_fn:合并一组样本以形成张量的mini-batch(从map-style的数据集中分批加载数据时使用)

pin_memory:默认False,若为True,数据加载器将复制张量到CUDA固定内存中

drop_last:默认False,是否删除最后一个不完整的 batch;假设数据集大小不能被 batch_size 整除,为Ture是将删除最后一个小batch,为 False 则保留

timeout:默认0,从工作进程收集一个batch的延迟值,应始终非负

worker_init_fn: 默认None,如果设置非None,则在设定种子后及数据加载前,将使用工作进程id作为输入在每个工作子进程上调用该函数

prefetch_factor:默认2,每个工作子进程预先加载的样本数量,总的预先加载样本数为 prefetch_factor * num_workers

persistent_workers:默认False,是否在使用数据集后关闭工作进程

使用该函数的关键还是传入合适的 dataset,设置合适的 batch_size,比如:

train_loader = Data.DataLoader(dataset=data_train_TD, batch_size=64, shuffle=True)

之后进行模型参数设定及训练即可

你可能感兴趣的:(PyTorch,Python,python)