[ 数据集 ] CIFAR-10 数据集介绍


Author :Horizon Max

编程技巧篇:各种操作小结

机器视觉篇:会变魔术 OpenCV

深度学习篇:简单入门 PyTorch

神经网络篇:经典网络模型

算法篇:再忙也别忘了 LeetCode


文章目录

  • CIFAR-10
  • 数据集读取
  • Pytorch 读取数据集
    • torchvision.datasets.CIFAR10()
    • torchvision.datasets.ImageFolder()

CIFAR-10

Size: 32×32 RGB图像 ,数据集本身是 BGR 通道
Num: 训练集 50000 和 测试集 10000,一共60000张图片
Classes: plane(飞机), car(汽车),bird(鸟),cat(猫),deer(鹿),dog(狗),frog(蛙类),horse(马),ship(船),truck(卡车)

[ 数据集 ] CIFAR-10 数据集介绍_第1张图片


官方下载链接:CIFAR-10


数据集读取

官网提供了数据集读取的方法(python3 version):

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

返回的是一个 字典

1)CIFAR-10数据集文件夹

[ 数据集 ] CIFAR-10 数据集介绍_第2张图片

2)将字典打印出来

dict = unpickle('./data_batch_1')
print(dict)
{b'batch_label': b'training batch 1 of 5', 
 b'labels': [6, 9 ... 1, 5], 
 b'data': array([[ 59,  43,  50, ..., 140,  84,  72],
                                 ...
                 [ 62,  61,  60, ..., 130, 130, 131]], dtype=uint8),
 b'filenames': [b'leptodactylus_pentadactylus_s_000004.png', b'camion_s_000148.png',
                                 ...
                b'estate_car_s_001433.png', b'cur_s_000170.png']}

b’batch_label’ : 所属文件集
b’labels’ : 图片标签
b’data’ :图片数据
b’filename’ :图片名称

3)打印类型

print(type(dict[b'batch_label']))
print(type(dict[b'labels']))
print(type(dict[b'data']))
print(type(dict[b'filenames']))
<class 'bytes'>
<class 'list'>
<class 'numpy.ndarray'>
<class 'list'>

4)打印图片类型

img = dict[b'data']
print(img.shape)
(10000, 3072)

其中 3072 = 32 * 32 * 3 (图片 size)

5)绘制图片

show_image = img[666]
img_reshape = show_image.reshape(3, 32, 32)
pic = img_reshape.transpose(1, 2, 0)    # (3, 32, 32) --> (32, 32, 3)
plt.imshow(pic)
plt.show()

label = dict[b'labels']
image_label = label[666]
print(image_label)

[ 数据集 ] CIFAR-10 数据集介绍_第3张图片
在这里插入图片描述

9

这张图片是个 9 truck(卡车)


Pytorch 读取数据集

torchvision.datasets.CIFAR10()

transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 是基于ImageNet 数据集得到的最佳归一化方案 ;
transforms 可以对图片进行裁剪、旋转等操作 ;

torchvision.datasets.CIFAR10() 未下载download=True 即可,已下载更改数据集地址 root='./dataset' download=False

# 数据集 类别
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
batch_size=64

# 训练集 数据归一化
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# 测试集 数据归一化
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# 训练集
trainset = torchvision.datasets.CIFAR10(
    root='./dataset', train=True, download=False, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=0)

# 测试集
testset = torchvision.datasets.CIFAR10(
    root='./dataset', train=False, download=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=True, num_workers=0)

torchvision.datasets.ImageFolder()

torchvision.datasets.ImageFolder() 使用这个函数需要将读取出来保存到 traintest 两个文件夹当中,并进行每一类的分类:

[ 数据集 ] CIFAR-10 数据集介绍_第4张图片
[ 数据集 ] CIFAR-10 数据集介绍_第5张图片
依此表示10个类别;

制作数据集:

import torch
import torchvision
import numpy as np
import cv2 as cv
from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method
from cleverhans.torch.attacks.projected_gradient_descent import projected_gradient_descent

batch_size = 50

transform_predict = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
image_data = torchvision.datasets.CIFAR10(
    root='../cifar10', train=False, download=False, transform=transform_predict)
image_loader = torch.utils.data.DataLoader(
    image_data, batch_size, shuffle=True, num_workers=0)

path = './dataset/train/'


def format(image):
    image = image.clone().detach().cpu().squeeze(0)
    image = np.around(image.mul(255))
    image = np.uint8(image).transpose(1, 2, 0)
    return image


def data(model):
    idx0 = 0
    idx1 = 0
    idx2 = 0
    idx3 = 0
    idx4 = 0
    idx5 = 0
    idx6 = 0
    idx7 = 0
    idx8 = 0
    idx9 = 0

    for i, (data, target) in enumerate(image_loader):
        print(i)
        example = data.to(device)

        for idx in range(len(example)):
            label = target[idx].item()
            image = format(example[idx])

            if label == 0:
                cv.imwrite(str(path + '0/plane{}.png').format(str(idx0)), image)
                idx0 += 1

            if label == 1:
                cv.imwrite(str(path + '1/car{}.png').format(str(idx1)), image)
                idx1 += 1

            if label == 2:
                cv.imwrite(str(path + '2/bird{}.png').format(str(idx2)), image)
                idx2 += 1

            if label == 3:
                cv.imwrite(str(path + '3/cat{}.png').format(str(idx3)), image)
                idx3 += 1

            if label == 4:
                cv.imwrite(str(path + '4/deer{}.png').format(str(idx4)), image)
                idx4 += 1

            if label == 5:
                cv.imwrite(str(path + '5/dog{}.png').format(str(idx5)), image)
                idx5 += 1

            if label == 6:
                cv.imwrite(str(path + '6/frog{}.png').format(str(idx6)), image)
                idx6 += 1

            if label == 7:
                cv.imwrite(str(path + '7/horse{}.png').format(str(idx7)), image)
                idx7 += 1

            if label == 8:
                cv.imwrite(str(path + '8/ship{}.png').format(str(idx8)), image)
                idx8 += 1

            if label == 9:
                cv.imwrite(str(path + '9/truck{}.png').format(str(idx9)), image)
                idx9 += 1

data(model)

数据集读取:

# 数据集 类别
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
batch_size=64

# 训练集 数据归一化
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# 测试集 数据归一化
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# 训练集
train_set = torchvision.datasets.ImageFolder(
    './dataset/train/', transform=transform_train)
train_loader = torch.utils.data.DataLoader(
    dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=0)

# 测试集
test_set = torchvision.datasets.ImageFolder(
    './dataset/test/', transform=transform_test)
test_loader = torch.utils.data.DataLoader(
    dataset=test_set, batch_size=batch_size, shuffle=True, num_workers=0)


你可能感兴趣的:(经典网络模型,人工智能,深度学习,数据集,CIFAR-10)