学习pytorch6 torchvision中的数据集使用

torchvision中的数据集使用

  • 1. torchvision中的数据集使用
    • 官网文档
    • 注意点1 totensor实例化不要忘记加括号
    • 注意点2 download可以一直保持为True
    • 代码
    • 执行结果
  • 2. DataLoader的使用

1. torchvision中的数据集使用

官网文档

注意左上角的版本

https://pytorch.org/vision/0.9/
学习pytorch6 torchvision中的数据集使用_第1张图片

注意点1 totensor实例化不要忘记加括号

totensor实例化不要忘记加括号,否则后面用数据集序列号的时候会报错
学习pytorch6 torchvision中的数据集使用_第2张图片

注意点2 download可以一直保持为True

download可以一直保持为True,下载一次后指定目录下有下载好的数据集,代码不会重复下载,也可以自己把下载好的数据集压缩包放到指定目录,代码会自动解压缩

代码

from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

# 用法1
# 数据下载很慢的话 可以使用迅雷下载,属性里面可以看到迅雷是从多方下载的,速度比较快 https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
train_set = datasets.CIFAR10(root='./dataset', train=True, download=True)
test_set = datasets.CIFAR10(root='./dataset', train=False, download=True)
# 下载的数据集是图片类型,可以debug查看数据
print(test_set[0])  # __getitem__ return img, target
print(type(test_set[0]))
img, target = test_set[0]
print(target)
print(test_set.classes[target])
print(img)
# PIL 图片可以直接show函数展示
img.show()

# 用法2
# 将数据集批量调用transforms,使用tensor数据类型
# trans_compose = transforms.Compose([transforms.ToTensor])  # 错误写法 会导致后面报错
trans_compose = transforms.Compose([transforms.ToTensor()])
train_set2 = datasets.CIFAR10(root='./dataset', train=True, transform=trans_compose, download=True)
test_set2 = datasets.CIFAR10(root='./dataset', train=False, transform=trans_compose, download=True)
print(type(test_set2[2]))
img, target = test_set2[0]
print(target)
print(test_set2.classes[target])
print(type(img))
writer = SummaryWriter("logs")
for i in range(10):
    img_tensor, target = test_set2[i]
    writer.add_image('tensor dataset', img_tensor, i)
writer.close()

执行结果

> p11_torchvision_dataset.py
Files already downloaded and verified
Files already downloaded and verified
(<PIL.Image.Image image mode=RGB size=32x32 at 0x1CF47DA9E20>, 3)
<class 'tuple'>
3
cat
<PIL.Image.Image image mode=RGB size=32x32 at 0x1CF47DA9E20>
Files already downloaded and verified
Files already downloaded and verified
<class 'tuple'>
3
cat
<class 'torch.Tensor'>

Process finished with exit code 0

2. DataLoader的使用

你可能感兴趣的:(python,学习pytorch,学习,python,pytorch,数据集)