【TensorFlow深度学习实战】VGG16实现CIFAR10数据集分类(上)

概要

为了实现VGG16网络对CIFAR10数据集的分类,我们首先得对CIFAR10进行一个详细介绍。并实现
本博客主要介绍Cifar10数据集的主要情况以及如何导入Cifar10数据集,并构造一个能类似于tensoflow中mnist数据集类,实现随机获取训练和测试小批量数据集。


Cifar10数据集说明

Cifar10数据集共有60000张彩色图像,这些图像是32*32,分为10个类,每类6000张图。其中,有50000张用于训练,构成了5个训练批,每一批10000张图;另外10000用于测试,单独构成一批。测试批的数据里,取自10类中的每一类,每一类随机取1000张。抽剩下的就随机排列组成了训练批。注意一个训练批中的各类图像并不一定数量相同,总的来看训练批,每一类都有5000张图。
下面这幅图就是列举了10各类,每一类展示了随机的10张图片:
【TensorFlow深度学习实战】VGG16实现CIFAR10数据集分类(上)_第1张图片
该数据是由以下三个人收集而来:Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton。第一位是AlexNet的提出者,第三位就更不用说了——深度学习的奠基人。

该数据集的下载网址为:http://www.cs.toronto.edu/~kriz/cifar.html 。这个数据主要有三个下载版本:Python、Matlab和二进制文件(适合于C语言)。由于我主要是利用tensorflow来实现VGG,因此我下载的是Python版本的数据集。从网站上可以看出,无论下载那个版本的数据集文件都不是挺大,足够学习跑跑程序用。
【TensorFlow深度学习实战】VGG16实现CIFAR10数据集分类(上)_第2张图片

Cifar10数据集的解压与保存

下面开始导入Cifar10数据集。将官网上下载的数据集打开之后,文件结构如下图所示。主要包含了5个data_batch文件data_batch_1至data_batch_5、1个test_batch文件和1个batches的meta文件。
【TensorFlow深度学习实战】VGG16实现CIFAR10数据集分类(上)_第3张图片
从Cifar10数据集官网上的介绍来看,5个data_batch文件和test_batch文件是利用pickel序列化之后的文件因此在导入 Cifar10数据集必须利用pickel进行解压数据,之后将数据还原。5个data_batch文件和test_batch文件分别代表5个训练集批次和测试集,因此我们首先利用pickel编写解压
函数:

def unpickle(data_path):
    """
    这是解压pickle数据的函数
    :param data_path: 数据路径
    """
    # 解压数据
    with open(data_path,'rb') as f:
        data_dict = pk.load(f,encoding='latin1')
        # 获取标签,形状(10000,)
        labels = np.array(data_dict['labels'])
        # 获取图像数据
        data = np.array(data_dict['data'])
        # 转换图像数据形状,形状为(10000,32,32,3)
        data = np.reshape(data,(10000,3,32,32)).transpose(0,2,3,1).astype("float")
    return data,labels

为了更好地构造数据集类,方便提取指定大小的小批量样本集,我们必须对整个数据进行一次接结构调整,test文件夹存测试图像集 ,train文件夹存训练图像集。图像命名为"分类_序号",如“apple_0”。代码如下:

# -*- coding: utf-8 -*-
# @Time    : 2019/6/1 10:30
# @Author  : DaiPuwei
# @Email   : [email protected]
# @Blog    : https://daipuweiai.blog.csdn.net/
# @File    : process_cifar.py
# @Software: PyCharm


import os
import cv2
import numpy as np
import pickle as pk
from Config import config as cfg

def unpickle(data_path):
    """
    这是解压pickle数据的函数
    :param data_path: 数据路径
    """
    # 解压数据
    with open(data_path,'rb') as f:
        data_dict = pk.load(f,encoding='latin1')
        # 获取标签,
        #labels = np.array(data_dict['fine_labels'])          # CIFAR100数据集
        labels = np.array(data_dict['labels'])                # CIFAR10数据集
        # 获取图像数据
        data = np.array(data_dict['data'])
        # 转换图像数据形状,形状为(10000,32,32,3)
        size = len(data)
        data = np.reshape(data,(size,3,32,32)).transpose(0,2,3,1).astype("float")
    return data,labels

def save_image(data_paths,label_names,save_path):
    """
    这是保存图像的函数
    :param data_paths: 数据集路径
    :param label_names: 标签名称
    :param save_path: 数据集保存路径
    """
    size = len(label_names)
    index_to_label = dict(zip(np.arange(size),label_names))
    label_to_cnt = dict(zip(label_names,[0]*size))
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    for path in data_paths:
        images,labels = unpickle(path)
        for image,label in zip(images,labels):
            image_name = index_to_label[label]       # label的分类
            cnt = label_to_cnt[image_name]           # 分类计数
            image_path = os.path.join(save_path,image_name+"_"+str(cnt)+".png")
            cv2.imwrite(image_path,image)       # 保存图像
            label_to_cnt[image_name] += 1       # 计数加1

def run_main():
    """
       这是主函数
    """
    dirpath = os.path.abspath("./cifar-10-bathches-py")
    dataset_path = "./"
    train_data_paths = ["train"]
    test_data_paths = ["test"]
    label_names = cfg.LABEL_NAMES
    for i,path in enumerate(train_data_paths):
        train_data_paths[i] = os.path.join(dirpath,path)
    for i,path in enumerate(test_data_paths):
        test_data_paths[i] = os.path.join(dirpath,path)
    print(np.shape(os.listdir("./train")))
    print(np.shape(os.listdir("./test")))
    save_image(train_data_paths,label_names,os.path.abspath(dataset_path+"train"))
    save_image(test_data_paths, label_names, os.path.abspath(dataset_path + "test"))

if __name__ == '__main__':
    run_main()

结果如下:
【TensorFlow深度学习实战】VGG16实现CIFAR10数据集分类(上)_第4张图片


数据集提取类的定义

之后主要我们要构造一个根据指定小批量样本集规模获取训练或测试集的数据类,这个可以参照tensorflow中为提取mnist数据集提供的接口**。在这个类中我们必须注意我们必须考虑到数据大小与CPU和GPU资源的关系,不能将数据一次性读入CPU/GPU缓存而导致资源溢出,因此我们必须实时读入小批量样本数据,那么我们必须用到python的yeild生成器。整个类的代码如下:

# -*- coding: utf-8 -*-
# @Time    : 2019/6/1 9:33
# @Author  : DaiPuwei
# @Email   : [email protected]
# @Blog    : https://daipuweiai.blog.csdn.net/
# @File    : DataSet.py
# @Software: PyCharm

import os
import cv2
import numpy as np
from Config import config as cfg

class DataSet(object):
    def __init__(self,one_hot = False):
        """
        这是数据集的初始化函数
        :param is_training: 是否训练阶段的标志
        :param one_hot: 标签是否采用one-hot编码格式
        """
        self.one_hot = one_hot
        self.train_path = os.path.join(cfg.DATASET_PATH,"train")            # 训练图像文件夹路径
        self.test_path = os.path.join(cfg.DATASET_PATH,"test")              # 测试图像文件夹路径
        self.train_lists = os.listdir(self.train_path)                      # 训练图像集名称列表
        self.test_lists = os.listdir(self.test_path)                        # 测试图像集名称列表
        self.train_dataset_size = len(self.train_lists)                     # 训练图像集规模
        self.test_dataset_size = len(self.test_lists)                       # 测试图像集规模
        self.class_num = cfg.CLASS_NUM              # 类别个数
        self.label_names = cfg.LABEL_NAMES          # 类别名称列表
        self.label_to_index = dict(zip(self.label_names,np.arange(self.class_num))) # 标签名称与序号之间的对应关系
        self.epoch = 1          # 用于训练网络是对应的周期数

    def next_batch(self,lists,dirpath,batch_size):
        """
        这是获取一个小批量数据集的函数
        :param lists: 图像数据集
        :param dirpath: 图像数据集文件夹路径
        :param batch_size: 小批量规模
        """
        start = 0
        size = len(lists)
        # 数据集被遍历一遍之后,立即返回初始位置,并完成随机打乱数据集集
        while start < size:
            if start >= size:
                start = 0
                self.epoch += 1
                np.random.shuffle(lists)
            end = np.min([start + batch_size, size])
            batch_image_paths = lists[start:end]
            yield self.get_batch(dirpath,batch_image_paths)         # 返回生成器
            start = start + batch_size

    def get_batch(self,dirpath,batch_image_paths):
        """
        这是读取一个小批量数据集的函数
        :param dirpath: 文件夹路径
        :param batch_image_paths: 小批量图像数据集的路径
        """
        images = []         # 训练图片集
        labels = []         # 训练标签集
        for path in batch_image_paths:
            # 获取标签
            label_name = path.split(".")[0]     # 获取标签名称
            if self.one_hot == True:
                _label = [0]*self.class_num
                _label[self.label_to_index[label_name]] = 1
                labels.append(_label)
            else:
                labels.append(self.label_to_index[label_name])
            # 读取图片
            image_path = os.path.abspath(os.path.join(dirpath, path))
            image = cv2.imread(image_path)
            images.append(image)
        return images,np.array(labels,dtype=np.float32)

    def Train_next_batch(self,batch_size):
        """
        这是获取一个小批量训练数据集的函数
        :param batch_size: 小批量规模
        """
        next_batch = self.next_batch(self.train_lists,self.train_path,batch_size)
        return next_batch.__next__()

    def Test_next_batch(self,batch_size):
        """
        这是获取一个小批量测试数据集的函数
        :param batch_size: 小批量规模
        """
        next_batch = self.next_batch(self.test_lists,self.test_path,batch_size)
        return next_batch.__next__()

你可能感兴趣的:(深度学习与计算机视觉,深度学习与计算机视觉)