CIFAR-10数据集由10个类的60000个32x32彩色图像组成,每个类有6000个图像。有50000个训练图像和10000个测试图像。训练图像分为5个批次,测试图像分为1个批次。
python代码如下:
import numpy as np
import random
import pickle
import platform
import os
#加载序列文件
def load_pickle(f):
    """Unpickle and return one object from the open binary file object *f*.

    The call is chosen by the running interpreter's major version: under
    Python 2 the stream is loaded as-is; under Python 3 it is loaded with
    ``encoding='latin1'`` so that 8-bit strings written by Python 2 can be
    decoded.

    Args:
        f: a file-like object opened for binary reading.

    Returns:
        The object stored in the pickle stream.

    Raises:
        ValueError: if the major version is neither '2' nor '3'.
    """
    # SECURITY: unpickling can execute arbitrary code. Only call this on
    # files from a trusted source (e.g. the official dataset archive).
    version = platform.python_version_tuple()  # tuple of strings, e.g. ('3', '8', '10')
    if version[0] == '2':
        return pickle.load(f)
    elif version[0] == '3':
        return pickle.load(f, encoding='latin1')
    raise ValueError("invalid python version:{}".format(version))
#处理原数据
def load_CIFAR_batch(filename):
    """Load one pickled CIFAR batch file and return its images and labels.

    Args:
        filename: path to a batch file. The unpickled object must provide a
            'data' entry (an array with 3*32*32 values per image) and a
            'labels' entry.

    Returns:
        A tuple ``(X, Y)``: ``X`` is a float array of shape
        ``(N, 32, 32, 3)`` and ``Y`` is the labels converted to a NumPy array.
    """
    with open(filename, 'rb') as f:
        datadict = load_pickle(f)
    X = datadict['data']
    Y = datadict['labels']
    # reshape() regroups the values without changing them; -1 lets NumPy
    # infer the number of images, so the batch size is not hard-coded.
    # transpose() then reorders the axes from (N, 3, 32, 32) to (N, 32, 32, 3).
    X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
    Y = np.array(Y)
    return X, Y
#返回可以直接使用的数据集
def load_CIFAR10(ROOT):
    """Load the five training batches and the test batch found under *ROOT*.

    Args:
        ROOT: directory containing the files 'data_batch_1' through
            'data_batch_5' and 'test_batch'.

    Returns:
        A tuple ``(Xtr, Ytr, Xte, Yte)``: the training images and labels
        concatenated across the five batches, followed by the test images
        and labels.
    """
    # os.path.join() combines the path components into one path.
    batch_paths = [
        os.path.join(ROOT, 'data_batch_%d' % (idx,)) for idx in range(1, 6)
    ]
    train_parts = [load_CIFAR_batch(path) for path in batch_paths]

    # np.concatenate() joins the per-batch arrays along the first axis.
    Xtr = np.concatenate([images for images, _ in train_parts])
    Ytr = np.concatenate([labels for _, labels in train_parts])

    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte
# Directory that holds the extracted CIFAR-10 python batch files.
datasets = './cifar-10-batches-py'
train_x, train_y, test_x, test_y = load_CIFAR10(datasets)

# Print the array shapes as a quick check that loading worked.
print('train_x shape:%s, train_y shape:%s'
      % tuple(arr.shape for arr in (train_x, train_y)))
print('test_x shape:%s, test_y shape:%s'
      % tuple(arr.shape for arr in (test_x, test_y)))
输出结果为:
train_x shape:(50000, 32, 32, 3), train_y shape:(50000,)
test_x shape:(10000, 32, 32, 3), test_y shape:(10000,)
其中50000为图片数量,32×32×3分别为图像的height、width、channel(高、宽、通道数)。
下篇为cifar-100数据集处理。