【PyTorch】使用DataLoader自定义数据集读取

【PyTorch】使用DataLoader自定义数据集读取

为了方便之后使用PyTorch的distributed部署,加速训练,将数据读取的方式改为适配pytorch提供的Dataset和DataLoader的方式。这里记录一下修改的要点:

1. 涉及的import库:

import torch
from torch.utils.data import Dataset, DataLoader

2. 自定义一个Dataset类:

  • 该类继承Dataset;

  • 可以定义若干个数据预处理的函数,关键的两个函数是:__len__()__getitem__();

  • __getitem__()实际是python支持的一个迭代器函数,编写时每次返回一个sample,不需要定义batch size,之后的DataLoader会自动帮忙读取数据组成batch的;

  • 举个栗子:

    class MyDataset(Dataset):
    	def __init__(self,data):
    		self.data = data
    	def __len__(self):
    		return len(self.data)
    	def __getitem__(self):
    		return self.data
    	def output(self):
    		print('output')
    

3. 初始化Dataset和DataLoader类:

  • DataLoader的参数可参考:https://blog.csdn.net/zyq12345678/article/details/90268668

  • 注意,如果在Dataset中每次返回的是自己定义的数据类型,或者是字典类型,有时要自己编写collate_fn()函数,告诉系统如何返回一个batch。

  • 举个栗子:

    dataset = MyDataset(data)
    dataloader = DataLoader(
        dataset,
        batch_size = 2,
        num_workers = 8,
        collate_fn = collate_fn,
        pin_memory = True
    )
    # 返回数据结构较复杂,包括自定义数据类型或字典时
    def collate_fn(batch):
        data = list(batch)
        return (data)
    
  • 如果遇到类似报错:

    TypeError: can't pickle _thread._local objects

    请将DataLoader中的num_workers参数设置为0,关闭多线程。原因可能是无法自动多线程处理复杂的数据类型。

4. 访问Dataloader内的Dataset类函数

  • 举个栗子:
for step, batch in enumerate(dataloader):
	dataloader.dataset.output()

你可能感兴趣的:(python,机器学习,pytorch,python,深度学习)