torchvision中数据集的使用

当我们想要使用torchvision中自带的数据集时,应该怎么做呢?

1 导包

import torchvision

2 下载

train_set = torchvision.datasets.CIFAR10(root="../dataset",train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root="../dataset",train=False,download=True)
参数解释:

root:数据集要存放的地址
train:值为True时下载训练集,值为False时下载测试集
download:一般均设置为True

3 使用

print(test_set[0])#打印test数据集的第一张图片的所有信息
print(test_set.classes)#打印数据集所有类别信息
img ,target = test_set[0]
print(img)#打印test数据集第一张图片的图片信息
print(target)#打印test数据集第一张图片的类别信息
img.show()#显示图片(因为这张图片的格式为PIL,故可以直接.show())
使用结果

—————————————————————————————————————————

4 扩展

和上节课的tensorboard、transforms等知识相结合后的代码及结果如下:

完整代码:

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root="../dataset",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="../dataset",train=False,transform=dataset_transform,download=True)

# print(test_set[0])
# print(test_set.classes)
#
# img ,target = test_set[0]
# print(img)
# print(target)
# img.show()

print(test_set[0])

writer = SummaryWriter("../logs/P10_logs")
for i in range(10):
    img , target = test_set[i]
    writer.add_image("test_set",img ,i)

writer.close()

结果:


torchvision

你可能感兴趣的:(torchvision中数据集的使用)