[PyTorch] Dataset and DataLoader

Contents

    • 1. Using the datasets in torchvision
    • 2. Using DataLoader

1. Using the datasets in torchvision

import torchvision
from torch.utils.tensorboard import SummaryWriter

# ToTensor converts each PIL image into a float tensor of shape (C, H, W) with values in [0, 1].
dataset_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# Load the built-in CIFAR10 dataset and apply the conversion via transform=dataset_transforms.
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transforms, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transforms, download=True)

# print(test_set[0])
# print(test_set.classes)  # list the class names (labels) of the test set
#
# img, target = test_set[0]
# print(img)
# print(target)
# print(test_set.classes[target])
# img.show()  # only works when no transform is set and img is still a PIL image

# Write the first 10 test images to TensorBoard under the tag "test_set".
writer = SummaryWriter("p10")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)

writer.close()
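
Indexing a dataset, as in test_set[i] above, returns one (image, label) pair. To make that contract visible without downloading anything, here is a minimal sketch of a map-style dataset. TinyImageDataset, its random images and its two class names are hypothetical stand-ins invented for illustration; they are not part of torchvision.

```python
import torch
from torch.utils.data import Dataset


class TinyImageDataset(Dataset):
    """Hypothetical stand-in for CIFAR10: random 3x32x32 tensors with integer labels."""

    classes = ["cat", "dog"]  # made-up class names, for illustration only

    def __init__(self, num_samples=10):
        generator = torch.Generator().manual_seed(0)
        # Random "images" with values in [0, 1), shaped like ToTensor output for CIFAR10.
        self.images = torch.rand(num_samples, 3, 32, 32, generator=generator)
        # Labels alternate between class index 0 and 1.
        self.targets = [i % len(self.classes) for i in range(num_samples)]

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.targets)

    def __getitem__(self, index):
        # One sample: (image tensor, integer class index).
        return self.images[index], self.targets[index]


tiny_set = TinyImageDataset()
img, target = tiny_set[0]
print(len(tiny_set), img.shape, tiny_set.classes[target])
```

Any object that implements __len__ and __getitem__ this way can be indexed and iterated like the CIFAR10 dataset, and can be passed to a DataLoader.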

2. Using DataLoader

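import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# download is not requested here, so CIFAR10 must already exist under ./dataset (e.g. from section 1).
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor())

# num_workers: number of subprocesses used for loading (0 means load in the main process)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

# A single sample from the dataset: an image tensor and its integer class index.
img, target = test_data[0]
print(img.shape)  # torch.Size([3, 32, 32])
print(target)

# Each iteration yields one batch; log every batch to TensorBoard as an image grid.
writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
    imgs, targets = data  # imgs: (batch, 3, 32, 32), targets: (batch,)
    writer.add_images("imgs", imgs, step)
    step = step + 1

writer.close()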
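
To see what the loop above receives on each iteration, the following sketch runs a DataLoader over synthetic data instead of CIFAR10, so it needs no download. The 100 random tensors and the fake_data name are assumptions made for illustration; only the DataLoader arguments mirror the code above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for CIFAR10: 100 random "images" and integer labels in [0, 10).
images = torch.rand(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))
fake_data = TensorDataset(images, labels)

loader = DataLoader(fake_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

# Collect the shape of every batch the loader yields in one pass.
batch_shapes = []
for imgs, targets in loader:
    batch_shapes.append((tuple(imgs.shape), tuple(targets.shape)))

print(batch_shapes)
```

The loader stacks individual samples along a new first dimension, so 100 samples with batch_size=64 give one full batch of 64 and a final batch of 36. With drop_last=True the batch of 36 would be skipped.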

Common DataLoader parameters

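  • dataset: the dataset to load samples from
  • batch_size: the number of samples fetched per batch
  • shuffle: whether to reshuffle the data at the start of every epoch. True: shuffled order; False: the dataset's original order
  • num_workers: the number of subprocesses used for loading; 0 means the data is loaded in the main process
  • drop_last: whether to drop the final batch when it holds fewer than batch_size samples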
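
The interaction between batch_size and drop_last comes down to simple arithmetic. The helper below is a sketch of that arithmetic in plain Python; num_batches is a hypothetical name, not a PyTorch function. It is meant to match what len(loader) reports for an indexable dataset such as CIFAR10.

```python
import math


def num_batches(num_samples, batch_size, drop_last):
    """Number of batches produced in one pass over num_samples items."""
    if drop_last:
        # The incomplete final batch is discarded.
        return num_samples // batch_size
    # The incomplete final batch is kept as a smaller batch.
    return math.ceil(num_samples / batch_size)


cifar10_test_size = 10000  # number of images in the CIFAR10 test split
print(num_batches(cifar10_test_size, 64, drop_last=False))  # 157
print(num_batches(cifar10_test_size, 64, drop_last=True))   # 156
print(cifar10_test_size % 64)                               # 16 images in the last batch
```

So with batch_size=64 and drop_last=False, the loop in section 2 runs 157 times, and the last batch holds only 16 images.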