Dataset和Dataloader的使用

目录

需要导入

dataset模板

dataset重写例子

使用方法

dataloader使用


Dataset如果是一叠扑克牌的话,DataLoader就是一只手,参数就是告诉这只手怎么抓取扑克牌。

DataLoader常用参数:
dataset (Dataset) – 使用哪个数据集

batch_size (int, optional) – 一次抓取多少个数据,多少张牌

shuffle (bool, optional) – 是否重新洗牌

num_workers (int, optional) – 多线程,windows下要设成0,不然会出错,出现
 

BrokenPipeError: [Errno 32] Broken pipe

drop_last (bool, optional) – 当数据集大小不能被批大小整除时,设置为True则以删除最后一个不完整的批。False,则不删除,最后一批将变小。(默认值:False)

datasetloader加载批数据过程:

Dataset和Dataloader的使用_第1张图片

需要导入

Dataset(torch.utils.data.Dataset)

dataset模板

class mydataset(Dataset):
    def __init__(self, xxx):
        ...
    
    def __getitem__(self, item):
        ...

    def __len()__(self):
        ...

dataset重写例子



from torch.utils.data import Dataset, DataLoader
import torch

class MyDataset(Dataset):
    """
        下载数据、初始化数据,都可以在这里完成
    """

    def __init__(self,file):
        self.x = torch.linspace(11,20,10)
        self.y = torch.linspace(1,10,10)
        self.len = len(self.x)

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return self.len



​

使用方法

dataset=mydataset(文件)

dataloader使用

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

train_data = torchvision.datasets.CIFAR10("dataset", train=True, transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10("dataset", train=False, transform=torchvision.transforms.ToTensor())

trian_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

img, target = train_data[0]
print(img.shape)
print(target)

writer = SummaryWriter("logs/dataloadlogs")
for epoh in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        #print(imgs.shape)
        #print(targets)
        writer.add_images("epoh {}".format(epoh), imgs, step)
        step = step+1
        print(step)

writer.close()

你可能感兴趣的:(深度学习,python,人工智能)