Pytorch学习日记——DataLoader的使用

学习视频——B站【小土堆】

Pytorch学习日记——DataLoader的使用_第1张图片

下载数据集(将数据集下载到dataset文件夹,数据集为torchvision中的cifar-10)

import torchvision

train_set = torchvision.datasets.CIFAR10(root="./dataset", train = True,download=True)
test_set = torchvision.datasets.CIFAR10(root="dataset",train=False,download=True)

dataloader的使用,代码

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

#准备的测试数据集
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())

test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0,drop_last=False)#最后一组不足64张时会舍去


#测试数据集中第一张图片target
img,target = test_data[0]
print(img.shape)
print(target)


writer = SummaryWriter("dataloader")
step=0
for data in test_loader:
    imgs, targets = data
    #print(imgs.shape)
    #print(targets)
    writer.add_images("test_data", imgs,step)
    step+=1

writer.close()

终端输入

tensorboard --logdir="dataloader"

结果

Pytorch学习日记——DataLoader的使用_第2张图片

你可能感兴趣的:(pytorch,学习,深度学习,python,人工智能)