Contents
Required imports
Dataset template
Example of overriding Dataset
Usage
Using DataLoader
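If a Dataset is a deck of playing cards, the DataLoader is the hand that draws from it, and its parameters tell that hand how to draw.
Commonly used DataLoader parameters:
dataset (Dataset) – the dataset to load from
batch_size (int, optional) – how many samples (cards) to draw at a time
shuffle (bool, optional) – whether to reshuffle the data at the start of every epoch
num_workers (int, optional) – number of worker subprocesses used for loading (these are processes, not threads); 0 loads the data in the main process. On Windows, a value above 0 typically fails with the error below unless the script's entry point is wrapped in if __name__ == '__main__':, so setting it to 0 is the simplest workaround:
BrokenPipeError: [Errno 32] Broken pipe
drop_last (bool, optional) – when the dataset size is not divisible by the batch size, True drops the last incomplete batch; False keeps it, so the last batch is smaller. (Default: False)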
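The effect of batch_size and drop_last is easiest to see on a tiny dataset. The following is a minimal sketch, assuming a made-up 10-sample TensorDataset (the names toy_dataset, keep_last and drop_last_loader are illustrative and not from the original post):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset of 10 samples, made up purely for illustration.
features = torch.arange(10, dtype=torch.float32)
labels = torch.arange(10)
toy_dataset = TensorDataset(features, labels)

# shuffle=False keeps the order fixed so the batch sizes are easy to read.
keep_last = DataLoader(toy_dataset, batch_size=4, shuffle=False, drop_last=False)
drop_last_loader = DataLoader(toy_dataset, batch_size=4, shuffle=False, drop_last=True)

# Collect the size of every batch each loader produces.
keep_sizes = [len(x) for x, _ in keep_last]
drop_sizes = [len(x) for x, _ in drop_last_loader]

print(keep_sizes)  # [4, 4, 2]
print(drop_sizes)  # [4, 4]
```

With 10 samples and a batch size of 4, the incomplete batch of 2 is either kept or dropped; len() on a loader reports the number of batches accordingly.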
How the DataLoader loads a batch of data:
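The order of operations can be sketched in plain Python. This is a simplified illustration and not PyTorch's actual implementation: the function iterate_batches is hypothetical, and the real DataLoader additionally merges the samples of each batch into tensors with a collate function and can use worker processes.

```python
import random


def iterate_batches(dataset, batch_size, shuffle=False, drop_last=False):
    """Yield lists of samples in the same order of steps a DataLoader follows."""
    indices = list(range(len(dataset)))        # 1. one index per sample (uses __len__)
    if shuffle:
        random.shuffle(indices)                # 2. optionally reshuffle the indices
    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start:start + batch_size]   # 3. group indices into a batch
        if drop_last and len(batch_indices) < batch_size:
            break                              # 4. skip the incomplete final batch
        yield [dataset[i] for i in batch_indices]            # 5. fetch samples (uses __getitem__)


cards = list(range(10))  # any object with __len__ and __getitem__ works here
batches = list(iterate_batches(cards, batch_size=4))
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

This is why a custom Dataset only has to provide __len__ and __getitem__: index generation, grouping and shuffling all happen outside the dataset.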
Dataset(torch.utils.data.Dataset)
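from torch.utils.data import Dataset

# Template: subclass Dataset and override these three methods
class MyDataset(Dataset):
    def __init__(self, xxx):
        # load or prepare the data here
        ...
    def __getitem__(self, item):
        # return the sample at index `item`
        ...
    def __len__(self):
        # return the number of samples
        ...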
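from torch.utils.data import Dataset, DataLoader
import torch

class MyDataset(Dataset):
    """
    Downloading and initializing the data can both be done here.
    """
    def __init__(self, file):
        # `file` is unused in this toy example; real code would read the data from it
        self.x = torch.linspace(11, 20, 10)  # 10 evenly spaced values from 11 to 20
        self.y = torch.linspace(1, 10, 10)   # 10 evenly spaced values from 1 to 10
        self.len = len(self.x)

    def __getitem__(self, index):
        # return one (x, y) sample
        return self.x[index], self.y[index]

    def __len__(self):
        # number of samples in the dataset
        return self.len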
dataset = MyDataset(file)  # `file` is a placeholder for your data source
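As a small end-to-end check, the custom dataset can be indexed directly and then handed to a DataLoader. The sketch below repeats the MyDataset definition so that it runs by itself; giving file a default of None is an assumption made here only so the class can be constructed without a real file.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self, file=None):  # `file` is ignored, as in the example above
        self.x = torch.linspace(11, 20, 10)
        self.y = torch.linspace(1, 10, 10)

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.x)


dataset = MyDataset()
print(len(dataset))            # 10, provided by __len__
x0, y0 = dataset[0]            # indexing calls __getitem__
print(x0.item(), y0.item())    # 11.0 1.0

# Draw the 10 samples in two batches of 5, in their original order.
loader = DataLoader(dataset, batch_size=5, shuffle=False)
batches = [(xb, yb) for xb, yb in loader]
print(len(batches))            # 2
print(batches[0][0].shape)     # torch.Size([5])
```

Each batch is a pair of tensors because the default collate step stacks the individual (x, y) samples returned by __getitem__.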
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# download=True fetches CIFAR10 into ./dataset on the first run if it is not already there
train_data = torchvision.datasets.CIFAR10("dataset", train=True, download=True,
                                          transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10("dataset", train=False, download=True,
                                         transform=torchvision.transforms.ToTensor())

train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

# A single sample from the dataset: an image tensor and its class index
img, target = train_data[0]
print(img.shape)
print(target)

writer = SummaryWriter("logs/dataloadlogs")
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        # print(imgs.shape)
        # print(targets)
        # log each batch of images under a per-epoch tag
        writer.add_images("epoch {}".format(epoch), imgs, step)
        step = step + 1
    print(step)  # number of batches logged in this epoch
writer.close()