cifar-10数据集处理

cifar-10数据集处理

CIFAR-10数据集由10个类的60000个32x32彩色图像组成,每个类有6000个图像。有50000个训练图像和10000个测试图像。训练图像分为5个批次,测试图像分为1个批次。

python代码如下:

import numpy as np
import random
import pickle
import platform
import os

#加载序列文件
def load_pickle(f):
    version=platform.python_version_tuple()#判断python的版本
    if version[0]== '2':
        return pickle.load(f)
    elif version[0]== '3':
        return pickle.load(f,encoding='latin1')
    raise ValueError("invalid python version:{}".format(version))
#处理原数据
def load_CIFAR_batch(filename):
    with open(filename,'rb') as f:
        datadict=load_pickle(f)
        X=datadict['data']
        Y=datadict['labels']
        X=X.reshape(10000,3,32,32).transpose(0,2,3,1).astype("float")
        #reshape()是在不改变矩阵的数值的前提下修改矩阵的形状,transpose()对矩阵进行转置
        Y=np.array(Y)
        return X,Y

#返回可以直接使用的数据集
def load_CIFAR10(ROOT):
    xs=[]
    ys=[]
    for b in range(1,6):
        f=os.path.join(ROOT,'data_batch_%d'%(b,))#os.path.join()将多个路径组合后返回
        X,Y=load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    Xtr=np.concatenate(xs)#这个函数用于将多个数组进行连接
    Ytr=np.concatenate(ys)
    del X,Y
    Xte,Yte=load_CIFAR_batch(os.path.join(ROOT,'test_batch'))
    return Xtr,Ytr,Xte,Yte
datasets = './cifar-10-batches-py'
train_x,train_y,test_x,test_y = load_CIFAR10(datasets)
print('train_x shape:%s, train_y shape:%s' % (train_x.shape, train_y.shape))
print('test_x shape:%s, test_y shape:%s' % (test_x.shape, test_y.shape))

输出结果为:

train_x shape:(50000, 32, 32, 3), train_y shape:(50000,)
test_x shape:(10000, 32, 32, 3), test_y shape:(10000,)

其中50000为图片数量,32323为图像的width,hight,channel

下篇为cifar-100数据集处理。

你可能感兴趣的:(cifar-10数据集处理)