CIFAR-10数据集下载及转换

Keras 中提供的 CIFAR-10 数据集可能因为网速等问题无法直接下载读取,此时可以从官网下载到本地,网址:
http://www.cs.toronto.edu/~kriz/cifar.html
这里我们下载 Python 版本的数据。
将下载的 tar.gz 文件解压,放到想要存放数据文件的文件夹中。解压得到的 cifar-10-batches-py 目录中包含下面脚本需要读取的 data_batch_1 至 data_batch_5 以及 test_batch 文件。这里我的文件存放位置为"/Users/shiruihuo/Documents/study/深度学习/data/cifar10/cifar-10-batches-py"。使用以下脚本可以正确地转换训练集和测试集的数据及标签。

# -*- coding: utf-8 -*-
import pickle as p
import numpy as np
import os


def load_CIFAR_batch(filename):
    """Load one pickled CIFAR-10 batch file (python version of the dataset).

    Args:
        filename: Path to a batch file such as ``data_batch_1`` or
            ``test_batch``.

    Returns:
        A tuple ``(X, Y)``:
        ``X`` is a float64 array of shape ``(N, 32, 32, 3)`` (channels last);
        ``Y`` is a 1-D NumPy array built from the batch's label list.
        ``N`` is derived from the data instead of being hard-coded, so batches
        of any size load (the official batches have 10000 rows each).

    Raises:
        ValueError: if the total number of values in ``data`` is not a
            multiple of 3 * 32 * 32 (NumPy cannot perform the reshape).
    """
    # NOTE(review): pickle can execute arbitrary code while loading; only use
    # this on batch files obtained from a trusted source.
    with open(filename, 'rb') as f:
        # encoding='bytes' makes Python-2 str keys load as bytes, which is why
        # the keys below are b'data' / b'labels'.
        datadict = p.load(f, encoding='bytes')
    X = datadict[b'data']
    Y = datadict[b'labels']
    # Each row holds three 32x32 channel planes one after another; reshape to
    # (N, C, H, W) and then move the channel axis last -> (N, H, W, C).
    X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
    Y = np.array(Y)
    return X, Y


def load_CIFAR10(ROOT):
    """Load the complete CIFAR-10 training and test sets from *ROOT*.

    Args:
        ROOT: Directory containing ``data_batch_1`` ... ``data_batch_5`` and
            ``test_batch`` (the extracted ``cifar-10-batches-py`` folder).

    Returns:
        ``(Xtr, Ytr, Xte, Yte)``: the five training batches stacked along the
        first axis, followed by the test batch. Each array comes from
        ``load_CIFAR_batch``.
    """
    batches = [
        load_CIFAR_batch(os.path.join(ROOT, 'data_batch_%d' % (b,)))
        for b in range(1, 6)
    ]
    # Stack the per-batch arrays along axis 0; with the official files Xtr
    # ends up with shape (50000, 32, 32, 3).
    Xtr = np.concatenate([X for X, _ in batches])
    Ytr = np.concatenate([Y for _, Y in batches])
    # concatenate() copied the data, so the per-batch arrays are now
    # redundant; drop them before loading the test batch to lower peak memory.
    del batches
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte

import numpy as np
#from julyedu.data_utils import load_CIFAR10
import matplotlib.pyplot as plt

# plt.rcParams['figure.figsize'] = (10.0, 8.0)
# plt.rcParams['image.interpolation'] = 'nearest'
# plt.rcParams['image.cmap'] = 'gray'

# Load the CIFAR-10 dataset. The directory can be overridden with the
# CIFAR10_DIR environment variable; by default the author's local path is used.
cifar10_dir = os.environ.get(
    'CIFAR10_DIR',
    '/Users/shiruihuo/Documents/study/深度学习/data/cifar10/cifar-10-batches-py',
)
x_train, y_train, x_test, y_test = load_CIFAR10(cifar10_dir)

# Sanity-check the shapes of the loaded data and label arrays.
print('Training data shape: ', x_train.shape)
print('Training labels shape: ', y_train.shape)
print('Test data shape: ', x_test.shape)
print('Test labels shape: ', y_test.shape)

你可能感兴趣的:(计算机视觉,keras,深度学习)