Xiaotudui PyTorch Tutorial Study Notes P14

P14. Using the datasets in torchvision

PyTorch website -> Docs (official documentation) -> torchvision

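torchvision.datasets datasets (COCO, MNIST, CIFAR10, ...)

torchvision.io input/output module

torchvision.models common neural networks (pretrained weights are available)

torchvision.ops some special operations

torchvision.transforms image transforms

torchvision.utils common small utilities (note: TensorBoard's SummaryWriter is imported from torch.utils.tensorboard, not from this module)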

The CIFAR10 dataset

class torchvision.datasets.CIFAR10(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)

Parameters

  • root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.

  • train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.

  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.

  • download (bool, optional) – If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
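To make the roles of transform and target_transform concrete, here is a minimal pure-Python sketch of how a dataset of this kind might apply them when a sample is indexed. ToyDataset and its fake samples are hypothetical stand-ins for illustration only; this is not torchvision's implementation.

```python
from typing import Any, Callable, List, Optional, Tuple


class ToyDataset:
    """Hypothetical stand-in that mimics the transform/target_transform parameters."""

    def __init__(
        self,
        samples: List[Tuple[Any, int]],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        self.samples = samples
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        img, target = self.samples[index]
        # transform acts on the input (the image) only
        if self.transform is not None:
            img = self.transform(img)
        # target_transform acts on the label only
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


# Fake "images" are short lists of pixel values in 0..255 (an assumption for the demo).
raw = [([0, 128, 255], 3), ([255, 255, 0], 8)]

plain = ToyDataset(raw)
scaled = ToyDataset(
    raw,
    transform=lambda px: [p / 255 for p in px],  # scale pixels to [0, 1]
    target_transform=lambda t: t + 100,          # arbitrary label remapping
)

print(plain[0])
print(scaled[0])
```

With no callables supplied, the sample comes back unchanged, which matches the default of None for both parameters in the signature above.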

import torchvision

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)  # training split
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)  # test split

# Download the dataset (console output follows)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./dataset/cifar-10-python.tar.gz
100.0%Extracting ./dataset/cifar-10-python.tar.gz to ./dataset
Files already downloaded and verified

import torchvision

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, download=True)

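# Without a transform, each sample is a (PIL image, class index) tuple
print(test_set[0])
# List of class names, indexed by class index
print(test_set.classes)

img, target = test_set[0]
print(img)
print(target)
# Look up the class name for this sample
print(test_set.classes[target])
# Open the image in the system's default image viewer
img.show()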

Files already downloaded and verified
Files already downloaded and verified
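(<PIL.Image.Image image mode=RGB size=32x32 at 0x...>, 3)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
<PIL.Image.Image image mode=RGB size=32x32 at 0x...>
3
cat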
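The lookup from target to class name is an ordinary list index. As a small standalone sketch (the class list is copied from the printed output; the inverse dictionary built here is only for illustration):

```python
# Class names copied from the printed test_set.classes output.
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Index -> name is a plain list lookup.
target = 3
name = classes[target]

# Name -> index: build the inverse mapping once.
class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}

print(name)
print(class_to_idx["cat"])
```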

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
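As a quick sanity check on these numbers, the arithmetic can be written out as follows. The 3 colour channels (RGB) are an assumption stated here rather than in the quoted description.

```python
num_classes = 10
images_per_class = 6000
train_images = 50000
test_images = 10000

# 10 classes x 6000 images per class should equal train + test.
total = num_classes * images_per_class
assert total == train_images + test_images

# Each image is 32x32 pixels; assuming 3 colour channels (RGB).
values_per_image = 32 * 32 * 3

print(total)
print(values_per_image)
```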

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)

# print(test_set[0])
# print(test_set.classes)
#
# img, target = test_set[0]
# print(img)
# print(target)
# print(test_set.classes[target])
# img.show()
#
# print(test_set[0])

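# Write the first 10 test images to TensorBoard logs in the "P14" directory
writer = SummaryWriter("P14")
for i in range(10):
    # With ToTensor applied, img is a tensor, which add_image accepts
    img, target = test_set[i]
    writer.add_image("test_set", img, i)  # i is the global step

writer.close()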

tensorboard --logdir="P14"
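add_image expects a tensor or NumPy array rather than a PIL image, which is why the script applies ToTensor first. As a rough pure-Python sketch of the layout change involved (height x width x channel integers in 0..255 becoming channel x height x width floats in 0..1), assuming a tiny hypothetical 2x2 image and a helper named hwc_to_chw that is not part of any library:

```python
from typing import List

Pixel = List[int]             # [R, G, B], each 0..255
ImageHWC = List[List[Pixel]]  # indexed as [row][col][channel]


def hwc_to_chw(image: ImageHWC) -> List[List[List[float]]]:
    """Rearrange H x W x C integer pixels into C x H x W floats in [0, 1]."""
    height = len(image)
    width = len(image[0])
    channels = len(image[0][0])
    return [
        [[image[r][c][ch] / 255 for c in range(width)] for r in range(height)]
        for ch in range(channels)
    ]


# Hypothetical 2x2 RGB image.
tiny = [
    [[255, 0, 0], [0, 255, 0]],
    [[0, 0, 255], [255, 255, 255]],
]

chw = hwc_to_chw(tiny)
print(len(chw), len(chw[0]), len(chw[0][0]))  # channels, height, width
print(chw[0])  # red channel
```

This only illustrates the shape and value range; the real transform returns a torch tensor, not nested lists.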
