【Pytorch】--CIFAR10等数据集本地读取

文章目录

      • 1.问题
      • 2.解决办法
      • 3.显示CIFAR10的图像

1.问题

\quad 在使用Pytorch的时候,有时候需要在线下载数据集,因为在下载的过程中,封装好的代码,还要进行其他的操作(例如数据类型转换numpy->tensor),但是有时候因为下载网站在国外,进度条一直显示0%,
\quad 就像这样:
在这里插入图片描述

2.解决办法

  • step1.下载数据集到本地
    【Pytorch】--CIFAR10等数据集本地读取_第1张图片

  • step2. 将本地存放CIFAR数据集路径放在浏览器下,回车
    【Pytorch】--CIFAR10等数据集本地读取_第2张图片

  • step3. 修改class CIFAR10(VisionDataset)中的url
    在这里插入图片描述

  • step4. 运行代码

    import torch
    import torchvision
    LOAD_CIFAR = True 
    DOWNLOAD_CIFAR = True
    
    train_data = torchvision.datasets.CIFAR10(
        root='./cifar10/', 
        train=True, 
        transform=torchvision.transforms.ToTensor(), 
        download=DOWNLOAD_CIFAR,
    )
    

    一切ok
    在这里插入图片描述

3.显示CIFAR10的图像

import torch
import torchvision
import matplotlib.pyplot as plt

EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_CIFAR = False

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

train_data = torchvision.datasets.CIFAR10(
    root='./cifar10/',  # 保存或者提取位置
    train=True,  # this is training data
    transform=torchvision.transforms.ToTensor(),  # 转换 PIL.Image or numpy.ndarray 成
    # torch.FloatTensor (C x H x W), 训练的时候 normalize 成 [0.0, 1.0] 区间
    download=DOWNLOAD_CIFAR,  # 没下载就下载, 下载了就不用再下了
)

# method 1 
# dataiter = iter(train_data)
# plt.show()
# for _ in range(len(train_data)):
#     images, labels = dataiter.__next__()
#     images = images.numpy().transpose(1, 2, 0)  # 把channel那一维放到最后
#     plt.title(str(classes[labels]))
#     plt.imshow(images)
#     plt.pause(1)

# as 2 list
# method 2
plt.show()
for images, labels in train_data:
    images = images.numpy().transpose(1, 2, 0)  # 把channel那一维放到最后
    plt.title(str(classes[labels]))
    plt.imshow(images)
    plt.pause(1)

显示出来的图像
【Pytorch】--CIFAR10等数据集本地读取_第3张图片【Pytorch】--CIFAR10等数据集本地读取_第4张图片

你可能感兴趣的:(Pytorch)