[Pytorch系列-33]:数据集 - torchvision与CIFAR10/CIFAR100详解

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

 本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055970


目录

第1章 TorchVision概述

1.1 TorchVision

1.2 TorchVision的安装

1.3 TorchVision官网的数据集

1.4 TorchVision常见的数据集概述

第2章 CIFAR10数据集

2.1 数据集概述

2.2 与 MNIST 数据集比较

2.3 下载地址

第3章 TorchVision对CIFAR10的支持

3.1 函数原型

3.2 数据下载前的准备

3.3 数据集下载与导入

3.4 显示单张样本图片

3.5 启动loader对象

3.6 显示批量图片

第4章 CIFAR100与CIFAR10的比较

4.1 相同点

4.2 不同点

第5章 图片集的手工下载

5.1 CIFAR10

5.2 CIFAR100



第1章 TorchVision概述

1.1 TorchVision

Pytorch非常有用的工具集:

  • torchtext:处理自然语言
  • torchaudio:处理音频的
  • torchvision:处理图像视频的。

torchvision包含一些常用的数据集、模型、转换函数等等。本文重点放在torchvision的数据集上。

1.2 TorchVision的安装

pip install torchvision 

1.3 TorchVision官网的数据集

https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/

1.4 TorchVision常见的数据集概述

  • MNIST
  • CIFAR10
  • CIFAR100
  • COCO(用于图像标注和目标检测)(Captioning and Detection)
  • LSUN Classification
  • ImageNet flowers
  • Imagenet-12
  • STL10

第2章 CIFAR10数据集

2.1 数据集概述

CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。

该数据集共有60000张彩色图像,这些图像是32*32,分为10个类RGB 彩色三通道图 片,每类6000张图。

其中,50000张用于训练,构成了5个训练批次,每一批10000张图;

其中,10000张用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批次。

注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。

CIFAR-10 的图片样例如图所示,包括

飞机( a叩lane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。
å¨è¿éæå¥å¾çæè¿°

2.2 与 MNIST 数据集比较

与 MNIST 数据集比较, CIFAR-10 具有以下不同点:

  • CIFAR-10 是 3 通道的彩色 RGB 图像,而 MNIST 是灰度图像。
  • CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。
  • 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。
  • 直接的全连接的线性模型,即使在MNIST表现良好,在 CIFAR-10数据集上表现得很差。

2.3 下载地址

官方下载地址:(很慢)

一共有三个版本:python,matlab,binary version 适用于C语言

http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

http://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz

http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz

第3章 TorchVision对CIFAR10的支持

3.1 函数原型

CIFAR10 (root, train=True, transform=None, target_transform=None, download=False)

  • root:存储数据集的根目录
  • train=True or false:训练集还是测试集
  • transform=None:在加载数据前的格式转换
  • target_transform=None:
  • download=False:是否需要在线下载

3.2 数据下载前的准备

#环境准备
import numpy as np              # numpy数组库
import math                     # 数学运算库
import matplotlib.pyplot as plt # 画图库

import torch             # torch基础库
import torchvision.datasets as dataset  #公开数据集的下载和管理
import torchvision.transforms as transforms  #公开数据集的预处理库,格式转换
import torchvision.utils as utils 
import torch.utils.data as data_utils  #对数据集进行分批加载的工具集

print("Hello World")
print(torch.__version__)
print(torch.cuda.is_available())

3.3 数据集下载与导入

如果本地没有数据集,会自动远程下载

#2-1 准备数据集
train_data = dataset.CIFAR10 (root = "cifar10",
                           train = True,
                           transform = transforms.ToTensor(),
                           download = True)

#2-1 准备数据集
test_data = dataset.MNIST(root = "cifar10",
                           train = False,
                           transform = transforms.ToTensor(),
                           download = True)

print(train_data)
print("size=", len(train_data))
print("")
print(test_data)
print("size=", len(test_data))
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10\cifar-10-python.tar.gz
Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10\cifar-10-python.tar.gz
100.0%
Extracting cifar10\cifar-10-python.tar.gz to cifar10
1.1%
Downloading http://183.207.33.38:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/train-images-idx3-ubyte.gz to cifar10\MNIST\raw\train-images-idx3-ubyte.gz
100.0%
Extracting cifar10\MNIST\raw\train-images-idx3-ubyte.gz to cifar10\MNIST\raw
102.8%
Downloading http://183.207.33.42:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/train-labels-idx1-ubyte.gz to cifar10\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting cifar10\MNIST\raw\train-labels-idx1-ubyte.gz to cifar10\MNIST\raw
5.0%
Downloading http://183.207.33.38:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/t10k-images-idx3-ubyte.gz to cifar10\MNIST\raw\t10k-images-idx3-ubyte.gz
100.0%
Extracting cifar10\MNIST\raw\t10k-images-idx3-ubyte.gz to cifar10\MNIST\raw
Downloading http://183.207.33.42:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/t10k-labels-idx1-ubyte.gz to cifar10\MNIST\raw\t10k-labels-idx1-ubyte.gz
112.7%
Extracting cifar10\MNIST\raw\t10k-labels-idx1-ubyte.gz to cifar10\MNIST\raw
Processing...
C:\ProgramData\Anaconda3\envs\pytorch1.8_py3.8\lib\site-packages\torchvision\datasets\mnist.py:479: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ..\torch\csrc\utils\tensor_numpy.cpp:143.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
Done!
Dataset CIFAR10
    Number of datapoints: 50000
    Root location: cifar10
    Split: Train
    StandardTransform
Transform: ToTensor()
size= 50000

Dataset MNIST
    Number of datapoints: 10000
    Root location: cifar10
    Split: Test
    StandardTransform
Transform: ToTensor()
size= 10000

3.4 显示单张样本图片

#原图不叠加噪声
#获取一张图片数据
print("原始Pytorch图片")
image, label = train_data[2]
print("torch image shape:", image.shape)
print("torch image label:", label)

print("\n通道转换后的Numpy图片")
image = image.numpy().transpose(1,2,0)  #交换维度,从GBR换成RGB
print("numpy image shape:", image.shape)
print("numpy image label:", label)

plt.imshow(image)
plt.show()

3.5 启动loader对象

# 批量数据读取
train_loader = data_utils.DataLoader(dataset = train_data,
                                  batch_size = 8,
                                  shuffle = True)

test_loader = data_utils.DataLoader(dataset = test_data,
                                  batch_size = 8,
                                  shuffle = True)

print(train_loader)
print(test_loader)
print(len(train_data), len(train_data)/8)
print(len(test_data),  len(test_data)/8)


50000 6250.0
10000 1250.0

3.6 显示批量图片

pytorch对图片的格式定义与Numpy对图片的格式定义是不一样的。

因此需要通过transpose()进行维度的变换。

#显示一个batch图片
print("获取一个batch组图片")
imgs, labels = next(iter(train_loader))
print(imgs.shape)
print(labels.shape)
print(labels.size()[0])

print("\n合并成一张三通道灰度图片")
images = utils.make_grid(imgs, nrow = 4)
print(images.shape)
print(labels.shape)

print("\n转换成imshow格式")
images = images.numpy().transpose(1,2,0) 
print(images.shape)
print(labels.shape)

print("\n显示图片")
plt.imshow(images)
plt.show()
获取一个batch组图片
torch.Size([8, 3, 32, 32])
torch.Size([8])
8

合并成一张三通道灰度图片
torch.Size([3, 70, 138])
torch.Size([8])

转换成imshow格式
(70, 138, 3)
torch.Size([8])

显示图片

[Pytorch系列-33]:数据集 - torchvision与CIFAR10/CIFAR100详解_第1张图片

第4章 CIFAR100与CIFAR10的比较

4.1 相同点

采用相同的图片布局:3 * 32 * 32 = 3072

4.2 不同点

  • 有100个类,每个类包含600个图像。
  • 每类各有500个训练图像和100个测试图像。
  • CIFAR-100中的100个类被分成20个超类。
  • 每个图像都带有一个“精细”标签(它所属的类)和一个“粗糙”标签(它所属的超类

以下是CIFAR-100中的20个超类别以及对应的子类:

超类 类别
水生哺乳动物 海狸,海豚,水獭,海豹,鲸鱼
水族馆的鱼,比目鱼,射线,鲨鱼,鳟鱼
花卉 兰花,罂粟花,玫瑰,向日葵,郁金香
食品容器 瓶子,碗,罐子,杯子,盘子
水果和蔬菜 苹果,蘑菇,橘子,梨,甜椒
家用电器 时钟,电脑键盘,台灯,电话机,电视机
家用家具 床,椅子,沙发,桌子,衣柜
昆虫 蜜蜂,甲虫,蝴蝶,毛虫,蟑螂
大型食肉动物 熊,豹,狮子,老虎,狼
大型人造户外用品 桥,城堡,房子,路,摩天大楼
大自然的户外场景 云,森林,山,平原,海
大杂食动物和食草动物 骆驼,牛,黑猩猩,大象,袋鼠
中型哺乳动物 狐狸,豪猪,负鼠,浣熊,臭鼬
非昆虫无脊椎动物 螃蟹,龙虾,蜗牛,蜘蛛,蠕虫
宝贝,男孩,女孩,男人,女人
爬行动物 鳄鱼,恐龙,蜥蜴,蛇,乌龟
小型哺乳动物 仓鼠,老鼠,兔子,母老虎,松鼠
树木 枫树,橡树,棕榈,松树,柳树
车辆1 自行车,公共汽车,摩托车,皮卡车,火车
车辆2 割草机,火箭,有轨电车,坦克,拖拉机

第5章 图片集的手工下载

5.1 CIFAR10

CIFAR-10 python版本
CIFAR-10 Matlab版本
CIFAR-10二进制版本(适用于C程序)

5.2 CIFAR100

CIFAR-100 python版本
CIFAR-100 Matlab版本
CIFAR-100二进制版本(适用于C程序)


作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

 本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055970

你可能感兴趣的:(人工智能-深度学习,人工智能-PyTorch,pytorch,深度学习,人工智能,数据集,CIFAR10)