作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客
本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055970
目录
第1章 TorchVision概述
1.1 TorchVision
1.2 TorchVision的安装
1.3 TorchVision官网的数据集
1.4 TorchVision常见的数据集概述
第2章 CIFAR10数据集
2.1 数据集概述
2.2 与 MNIST 数据集比较
2.3 下载地址
第3章 TorchVision对CIFAR10的支持
3.1 函数原型
3.2 数据下载前的准备
3.3 数据集下载与导入
3.4 显示单张样本图片
3.5 启动loader对象
3.6 显示批量图片
第4章 CIFAR100与CIFAR10的比较
4.1 相同点
4.2 不同点
第5章 图片集的手工下载
5.1 CIFAR10
5.2 CIFAR100
Pytorch非常有用的工具集:
torchvision包含一些常用的数据集、模型、转换函数等等。本文重点放在torchvision的数据集上。
pip install torchvision
https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/
CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。
该数据集共有60000张彩色图像,这些图像是32*32,分为10个类RGB 彩色三通道图 片,每类6000张图。
其中,50000张用于训练,构成了5个训练批次,每一批10000张图;
其中,10000张用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批次。
注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。
CIFAR-10 的图片样例如图所示,包括
飞机( a叩lane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。
与 MNIST 数据集比较, CIFAR-10 具有以下不同点:
官方下载地址:(很慢)
一共有三个版本:python,matlab,binary version 适用于C语言
http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
http://www.cs.toronto.edu/~kriz/cifar-10-matlab.tar.gz
http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
CIFAR10 (root, train=True, transform=None, target_transform=None, download=False)
#环境准备
import numpy as np # numpy数组库
import math # 数学运算库
import matplotlib.pyplot as plt # 画图库
import torch # torch基础库
import torchvision.datasets as dataset #公开数据集的下载和管理
import torchvision.transforms as transforms #公开数据集的预处理库,格式转换
import torchvision.utils as utils
import torch.utils.data as data_utils #对数据集进行分批加载的工具集
print("Hello World")
print(torch.__version__)
print(torch.cuda.is_available())
如果本地没有数据集,会自动远程下载
#2-1 准备数据集
train_data = dataset.CIFAR10 (root = "cifar10",
train = True,
transform = transforms.ToTensor(),
download = True)
#2-1 准备数据集
test_data = dataset.MNIST(root = "cifar10",
train = False,
transform = transforms.ToTensor(),
download = True)
print(train_data)
print("size=", len(train_data))
print("")
print(test_data)
print("size=", len(test_data))
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10\cifar-10-python.tar.gz
Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10\cifar-10-python.tar.gz
100.0%
Extracting cifar10\cifar-10-python.tar.gz to cifar10
1.1%
Downloading http://183.207.33.38:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/train-images-idx3-ubyte.gz to cifar10\MNIST\raw\train-images-idx3-ubyte.gz
100.0%
Extracting cifar10\MNIST\raw\train-images-idx3-ubyte.gz to cifar10\MNIST\raw
102.8%
Downloading http://183.207.33.42:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/train-labels-idx1-ubyte.gz to cifar10\MNIST\raw\train-labels-idx1-ubyte.gz Extracting cifar10\MNIST\raw\train-labels-idx1-ubyte.gz to cifar10\MNIST\raw
5.0%
Downloading http://183.207.33.38:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/t10k-images-idx3-ubyte.gz to cifar10\MNIST\raw\t10k-images-idx3-ubyte.gz
100.0%
Extracting cifar10\MNIST\raw\t10k-images-idx3-ubyte.gz to cifar10\MNIST\raw Downloading http://183.207.33.42:9011/yann.lecun.com/c3pr90ntc0td/exdb/mnist/t10k-labels-idx1-ubyte.gz to cifar10\MNIST\raw\t10k-labels-idx1-ubyte.gz
112.7%
Extracting cifar10\MNIST\raw\t10k-labels-idx1-ubyte.gz to cifar10\MNIST\raw Processing...
C:\ProgramData\Anaconda3\envs\pytorch1.8_py3.8\lib\site-packages\torchvision\datasets\mnist.py:479: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ..\torch\csrc\utils\tensor_numpy.cpp:143.) return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
Done! Dataset CIFAR10 Number of datapoints: 50000 Root location: cifar10 Split: Train StandardTransform Transform: ToTensor() size= 50000 Dataset MNIST Number of datapoints: 10000 Root location: cifar10 Split: Test StandardTransform Transform: ToTensor() size= 10000
#原图不叠加噪声
#获取一张图片数据
print("原始Pytorch图片")
image, label = train_data[2]
print("torch image shape:", image.shape)
print("torch image label:", label)
print("\n通道转换后的Numpy图片")
image = image.numpy().transpose(1,2,0) #交换维度,从GBR换成RGB
print("numpy image shape:", image.shape)
print("numpy image label:", label)
plt.imshow(image)
plt.show()
# 批量数据读取
train_loader = data_utils.DataLoader(dataset = train_data,
batch_size = 8,
shuffle = True)
test_loader = data_utils.DataLoader(dataset = test_data,
batch_size = 8,
shuffle = True)
print(train_loader)
print(test_loader)
print(len(train_data), len(train_data)/8)
print(len(test_data), len(test_data)/8)
50000 6250.0 10000 1250.0
pytorch对图片的格式定义与Numpy对图片的格式定义是不一样的。
因此需要通过transpose()进行维度的变换。
#显示一个batch图片
print("获取一个batch组图片")
imgs, labels = next(iter(train_loader))
print(imgs.shape)
print(labels.shape)
print(labels.size()[0])
print("\n合并成一张三通道灰度图片")
images = utils.make_grid(imgs, nrow = 4)
print(images.shape)
print(labels.shape)
print("\n转换成imshow格式")
images = images.numpy().transpose(1,2,0)
print(images.shape)
print(labels.shape)
print("\n显示图片")
plt.imshow(images)
plt.show()
获取一个batch组图片 torch.Size([8, 3, 32, 32]) torch.Size([8]) 8 合并成一张三通道灰度图片 torch.Size([3, 70, 138]) torch.Size([8]) 转换成imshow格式 (70, 138, 3) torch.Size([8]) 显示图片
采用相同的图片布局:3 * 32 * 32 = 3072
以下是CIFAR-100中的20个超类别以及对应的子类:
超类 | 类别 |
---|---|
水生哺乳动物 | 海狸,海豚,水獭,海豹,鲸鱼 |
鱼 | 水族馆的鱼,比目鱼,射线,鲨鱼,鳟鱼 |
花卉 | 兰花,罂粟花,玫瑰,向日葵,郁金香 |
食品容器 | 瓶子,碗,罐子,杯子,盘子 |
水果和蔬菜 | 苹果,蘑菇,橘子,梨,甜椒 |
家用电器 | 时钟,电脑键盘,台灯,电话机,电视机 |
家用家具 | 床,椅子,沙发,桌子,衣柜 |
昆虫 | 蜜蜂,甲虫,蝴蝶,毛虫,蟑螂 |
大型食肉动物 | 熊,豹,狮子,老虎,狼 |
大型人造户外用品 | 桥,城堡,房子,路,摩天大楼 |
大自然的户外场景 | 云,森林,山,平原,海 |
大杂食动物和食草动物 | 骆驼,牛,黑猩猩,大象,袋鼠 |
中型哺乳动物 | 狐狸,豪猪,负鼠,浣熊,臭鼬 |
非昆虫无脊椎动物 | 螃蟹,龙虾,蜗牛,蜘蛛,蠕虫 |
人 | 宝贝,男孩,女孩,男人,女人 |
爬行动物 | 鳄鱼,恐龙,蜥蜴,蛇,乌龟 |
小型哺乳动物 | 仓鼠,老鼠,兔子,母老虎,松鼠 |
树木 | 枫树,橡树,棕榈,松树,柳树 |
车辆1 | 自行车,公共汽车,摩托车,皮卡车,火车 |
车辆2 | 割草机,火箭,有轨电车,坦克,拖拉机 |
CIFAR-10 python版本
CIFAR-10 Matlab版本
CIFAR-10二进制版本(适用于C程序)
CIFAR-100 python版本
CIFAR-100 Matlab版本
CIFAR-100二进制版本(适用于C程序)
作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客
本文网址:https://blog.csdn.net/HiWangWenBing/article/details/121055970