github链接:https://github.com/gdutthu/Statistical-learning-method
知乎专栏链接:https://zhuanlan.zhihu.com/c_1257792845504708608
因为在cs231n的作业中需要用到cifar 10数据集。在这里对读取该数据集的方法进行一些简单总结。
cifar-10数据集下载链接:http://www.cs.toronto.edu/~kriz/cifar.html
由于数据集中每张图像为32x32,有RGB3个通道,按照RGB通道顺序以及每一通道按照行的顺序已排列好,一个训练样本对应一行有32x32x3=3072个值。
其中batches.meta记录了数据集中十个类别的对应信息,data_batch_1到data_batch_5存在了训练样本,test_batch存放了测试样本,具体信息可查看上面的链接。
由于数据集中每张图像为32x32,有RGB3个通道,按照RGB通道顺序以及每一通道按照行的顺序已排列好,一个训练样本对应一行有32x32x3=3072个值。用以上的unpickle函数处理后会返回一个字典dict,dict中的key中有数据和标签,value为对应的值。如dict[b’data’]返回一个unit8的10000x3072维numpy的array。dict[b’labels’]返回一个长度为10000的list,list中的值为0-9,每个数字表示一个类别,具体对应类别可由batchs.meta获取。
采用TensorFlow加载cifar 10数据集(推荐)
1、下载cifar 10数据集数据集(下载Python版本数据集)。下载链接:http://www.cs.toronto.edu/~kriz/cifar.html
2、修改文件名。将原文件名cifar-10-python.tar.gz改成cifar-10-batches-py.tar.gz
3、移动文件位置。将修改名字后的文件移动到 C:\Users{你的用户名}.keras\datasets
4、采用TensorFlow读取数据集
import tensorflow as tf
import numpy as np
#加载cifar-10数据集
def dataLoad():
(train_data, train_label), (test_data, test_label) = tf.keras.datasets.cifar10.load_data()
#数据集进行归一化
train_data=train_data/255
test_data=test_data/255
#将标签数据集从数组类型array修改成整形类型int
train_label.astype(np.int)
test_label.astype(np.int)
return (train_data, train_label), (test_data, test_label)
1、下载cifar 10数据集数据集(下载Python版本数据集)。下载链接:http://www.cs.toronto.edu/~kriz/cifar.html
将下载好的数据集解压到相对路径中,文件名为cifar-10-batches-py
2、采用官网读取数据集的方法
数据集官网上提供了python3读取CIFAR-10的方式,以下函数可以将数据集转化为字典类型:
def unpickle(file):
import pickle
with open(file, 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
3、合并数据集
import numpy as np
import os
#创建训练样本和测试样本
def CreatData():
#创建训练样本
#依次加载batch_data_i,并合并到x,y
x=[]
y=[]
for i in range(1,6):
batch_path='cifar-10-batches-py\data_batch_%d'%(i)
batch_dict=unpickle(batch_path)
train_batch=batch_dict[b'data'].astype('float')
train_labels=np.array(batch_dict[b'labels'])
x.append(train_batch)
y.append(train_labels)
#将5个训练样本batch合并为50000x3072,标签合并为50000x1
#np.concatenate默认axis=0,为纵向连接
traindata=np.concatenate(x)
trainlabels=np.concatenate(y)
#创建测试样本
#直接写cifar-10-batches-py\test_batch会报错,因此把/t当作制表符了,应用\\;
# test_dict=unpickle("cifar-10-batches-py\\test_batch")
#建议使用os.path.join()函数
testpath=os.path.join('cifar-10-batches-py','test_batch')
test_dict=unpickle(testpath)
testdata=test_dict[b'data'].astype('float')
testlabels=np.array(test_dict[b'labels'])
return traindata,trainlabels,testdata,testlabels