PyTorch Deep Learning: Using DataLoader

How to use DataLoader. The meaning of each parameter is described in the official documentation:
DataLoader
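
Before the CIFAR10 example, here is a minimal sketch (not from the original article) of the Dataset interface that DataLoader builds on: any object implementing __getitem__ and __len__ can be wrapped and batched the same way. The RandomImageDataset class and its random data are made up purely for illustration.

import torch
from torch.utils.data import Dataset, DataLoader

class RandomImageDataset(Dataset):
    """Hypothetical toy dataset: random 3x32x32 'images' with labels 0-9."""
    def __init__(self, n=100):
        self.data = torch.rand(n, 3, 32, 32)
        self.labels = torch.randint(0, 10, (n,))

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

    def __len__(self):
        return len(self.data)

loader = DataLoader(RandomImageDataset(), batch_size=16, shuffle=True)
imgs, labels = next(iter(loader))
print(imgs.shape)    # torch.Size([16, 3, 32, 32])
print(labels.shape)  # torch.Size([16])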


import torchvision

# Prepare the test dataset
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor(), download=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)

#batch_size (int, optional) – how many samples per batch to load (default: 1).
#shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
#num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
#drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
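
# Illustration (added, not part of the original code): CIFAR10's test split
# holds 10000 images, so with batch_size=64 and drop_last=True the loader
# yields 10000 // 64 = 156 full batches; with drop_last=False it would yield
# 157 batches, the last one containing only 10000 % 64 = 16 images.
print(len(test_loader))    # 156, because drop_last=True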


# The first image and label in the test dataset
img,target = test_data[0]
print(img.shape)
print(target)
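
# For contrast (added illustration): one batch from the DataLoader stacks 64
# such samples along a new leading batch dimension.
batch_imgs, batch_targets = next(iter(test_loader))
print(batch_imgs.shape)     # torch.Size([64, 3, 32, 32])
print(batch_targets.shape)  # torch.Size([64])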

writer = SummaryWriter("dataloader")
for epoch in range(2):     # read through the dataset for two epochs
    step = 0
    for data in test_loader:
        imgs,targets = data
        writer.add_images("Epoch:{}".format(epoch), imgs, step)   # a batch of images, so add_images rather than add_image
        step = step+1

writer.close()
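
After the script runs, the logged batches can be inspected in TensorBoard. Assuming TensorBoard is installed and the command is run from the same working directory that contains the "dataloader" log folder:

tensorboard --logdir=dataloader

The Epoch:0 and Epoch:1 tags then show the batches step by step; because shuffle=True, the two epochs display different orderings of the data.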

