读取cifar-10数据集

读取cifar-10数据集

  • 1 数据集介绍
  • 2 读取方法
    • 2.1 方法一:TensorFlow读取数据集
    • 2.2 方法二:读取离线数据集

github链接:https://github.com/gdutthu/Statistical-learning-method
知乎专栏链接:https://zhuanlan.zhihu.com/c_1257792845504708608

因为在cs231n的作业中需要用到cifar 10数据集。在这里对读取该数据集的方法进行一些简单总结。
cifar-10数据集下载链接:http://www.cs.toronto.edu/~kriz/cifar.html

1 数据集介绍

由于数据集中每张图像为32x32,有RGB3个通道,按照RGB通道顺序以及每一通道按照行的顺序已排列好,一个训练样本对应一行有32x32x3=3072个值。
读取cifar-10数据集_第1张图片
其中batches.meta记录了数据集中十个类别的对应信息,data_batch_1到data_batch_5存在了训练样本,test_batch存放了测试样本,具体信息可查看上面的链接。

由于数据集中每张图像为32x32,有RGB3个通道,按照RGB通道顺序以及每一通道按照行的顺序已排列好,一个训练样本对应一行有32x32x3=3072个值。用以上的unpickle函数处理后会返回一个字典dict,dict中的key中有数据和标签,value为对应的值。如dict[b’data’]返回一个unit8的10000x3072维numpy的array。dict[b’labels’]返回一个长度为10000的list,list中的值为0-9,每个数字表示一个类别,具体对应类别可由batchs.meta获取。

2 读取方法

2.1 方法一:TensorFlow读取数据集

采用TensorFlow加载cifar 10数据集(推荐)
1、下载cifar 10数据集数据集(下载Python版本数据集)。下载链接:http://www.cs.toronto.edu/~kriz/cifar.html

2、修改文件名。将原文件名cifar-10-python.tar.gz改成cifar-10-batches-py.tar.gz

3、移动文件位置。将修改名字后的文件移动到 C:\Users{你的用户名}.keras\datasets

4、采用TensorFlow读取数据集

import  tensorflow as  tf
import numpy as np

#加载cifar-10数据集
def dataLoad():
    (train_data, train_label), (test_data, test_label) =  tf.keras.datasets.cifar10.load_data()
    #数据集进行归一化
    train_data=train_data/255
    test_data=test_data/255
    #将标签数据集从数组类型array修改成整形类型int
    train_label.astype(np.int)
    test_label.astype(np.int)
    return (train_data, train_label), (test_data, test_label)

2.2 方法二:读取离线数据集

1、下载cifar 10数据集数据集(下载Python版本数据集)。下载链接:http://www.cs.toronto.edu/~kriz/cifar.html

将下载好的数据集解压到相对路径中,文件名为cifar-10-batches-py

2、采用官网读取数据集的方法

数据集官网上提供了python3读取CIFAR-10的方式,以下函数可以将数据集转化为字典类型:

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

3、合并数据集

import numpy as np
import os

#创建训练样本和测试样本
def CreatData():
    #创建训练样本
    #依次加载batch_data_i,并合并到x,y
    x=[]
    y=[]
    for i in range(1,6):
        batch_path='cifar-10-batches-py\data_batch_%d'%(i)
        batch_dict=unpickle(batch_path)
        train_batch=batch_dict[b'data'].astype('float')
        train_labels=np.array(batch_dict[b'labels'])
        x.append(train_batch)
        y.append(train_labels)
    #将5个训练样本batch合并为50000x3072,标签合并为50000x1
    #np.concatenate默认axis=0,为纵向连接
    traindata=np.concatenate(x)
    trainlabels=np.concatenate(y)
    
    #创建测试样本
    #直接写cifar-10-batches-py\test_batch会报错,因此把/t当作制表符了,应用\\;
    #    test_dict=unpickle("cifar-10-batches-py\\test_batch")
    
    #建议使用os.path.join()函数
    testpath=os.path.join('cifar-10-batches-py','test_batch')
    test_dict=unpickle(testpath)
    testdata=test_dict[b'data'].astype('float')
    testlabels=np.array(test_dict[b'labels'])
    
    return traindata,trainlabels,testdata,testlabels

你可能感兴趣的:(cs231n)