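The classes are completely mutually exclusive. There is no overlap between the automobile and truck classes: "Automobile" includes sedans, SUVs, and similar vehicles, while "Truck" includes only big trucks. Neither class includes pickup trucks.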
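def unpickle(file):
    # Python 3 version; the original Python 2 snippet imported cPickle instead.
    import pickle
    with open(file, 'rb') as fo:
        # encoding='latin1' reads the Python 2 pickles and keeps the keys as str
        batch = pickle.load(fo, encoding='latin1')
    return batch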
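Each batch file contains a dictionary with the following elements:
label_names: a 10-element list that gives meaningful names to the numeric labels in labels. For example, label_names[0] == "airplane", label_names[1] == "automobile", and so on.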
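As a small illustration of how label_names is used, the sketch below maps numeric labels to class names. The dictionary meta is a hand-built stand-in for the contents of batches.meta, and names_for and the example labels are hypothetical names and values introduced only for this example.

```python
# Hand-built stand-in for the dictionary stored in batches.meta
# (assumption: only the label_names entry is reproduced here).
meta = {
    "label_names": ["airplane", "automobile", "bird", "cat", "deer",
                    "dog", "frog", "horse", "ship", "truck"]
}


def names_for(labels, label_names):
    """Translate numeric class labels into their human-readable names."""
    return [label_names[i] for i in labels]


example_labels = [0, 1, 9]  # made-up labels, not read from a real batch
example_names = names_for(example_labels, meta["label_names"])
print(example_names)
```

With the real dataset, meta would come from unpickling batches.meta rather than from a literal.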
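The dataset contains 100 fine classes with 600 images each: 500 training images and 100 test images per class. The 100 classes are grouped into 20 superclasses. Each image carries a "fine" label (the class it belongs to) and a "coarse" label (the superclass it belongs to).
The layout is the same as in the Python version of the CIFAR-10 dataset.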
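To make the two label levels concrete, here is a minimal sketch that one-hot encodes both the fine and the coarse labels. It assumes the batch dictionary exposes them under the keys 'fine_labels' and 'coarse_labels'. The dictionary fake_batch and its label values are made up for illustration and do not reflect the real class-to-superclass assignment.

```python
import numpy as np


def one_hot(x, n):
    """Convert a 1-D sequence of class indices to one-hot rows."""
    x = np.asarray(x)
    assert x.ndim == 1
    return np.eye(n)[x]


# Synthetic stand-in for a CIFAR-100 batch holding three blank images.
fake_batch = {
    "data": np.zeros((3, 3072), dtype=np.uint8),
    "fine_labels": [0, 57, 99],     # one of 100 fine classes per image
    "coarse_labels": [4, 15, 19],   # one of 20 superclasses (made-up values)
}

fine = one_hot(fake_batch["fine_labels"], n=100)
coarse = one_hot(fake_batch["coarse_labels"], n=20)
print(fine.shape, coarse.shape)
```

Each row of fine and coarse contains a single 1 in the column given by the corresponding label.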
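(1) The CIFAR-10 dataset is stored under the relative path data_dir_cifar10.
(2) The _load_batch_cifar10 function
This function loads a batch file in the CIFAR-10 format. The location of the batch file is obtained by joining the relative path data_dir_cifar10 with the batch file name filename. The file is loaded with numpy's load function (the load function of cPickle, which is pickle in Python 3, also works), which returns batch, a dictionary holding the data and the labels. The image data is read through the key 'data' and the class labels through the key 'labels'. The pixel values are scaled to [0, 1] and the labels are converted to one-hot encoding (see the description of the MNIST dataset in the previous article). Finally, the elements of both arrays are cast to the common type dtype.
(3) The concatenate function
With axis=0, this function stacks the matrices from top to bottom in order (they must have the same number of columns). With axis=1, it places them side by side from left to right (they must have the same number of rows).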
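The steps above can be sketched on synthetic data, without downloading the dataset. The example writes two small fake batches to a temporary directory, loads them, and concatenates the results. The helper load_batch, the temporary files, and the random contents are assumptions made for illustration. Because the fake batches are written by the same Python version that reads them, no encoding argument is passed to pickle.load here, whereas the original Python 2 pickles would need one under Python 3.

```python
import os
import pickle
import tempfile

import numpy as np


def one_hot(x, n):
    """Convert a 1-D sequence of class indices to one-hot rows."""
    x = np.asarray(x)
    assert x.ndim == 1
    return np.eye(n)[x]


def load_batch(path, n_classes=10, dtype="float64"):
    """Load one pickled batch, scale pixels to [0, 1], one-hot the labels."""
    with open(path, "rb") as f:
        batch = pickle.load(f)
    data = batch["data"] / 255.0
    labels = one_hot(batch["labels"], n=n_classes)
    return data.astype(dtype), labels.astype(dtype)


tmp_dir = tempfile.mkdtemp()
rng = np.random.RandomState(0)
xs, ts = [], []
for k in range(2):
    # Fake batch: 4 random "images" of 3072 bytes each and 4 random labels.
    fake = {
        "data": rng.randint(0, 256, size=(4, 3072)).astype(np.uint8),
        "labels": [int(v) for v in rng.randint(0, 10, size=4)],
    }
    path = os.path.join(tmp_dir, "data_batch_%d" % (k + 1))
    with open(path, "wb") as f:
        pickle.dump(fake, f)
    x, t = load_batch(path)
    xs.append(x)
    ts.append(t)

x_all = np.concatenate(xs, axis=0)  # stack batches top to bottom: (8, 3072)
t_all = np.concatenate(ts, axis=0)  # same for the labels: (8, 10)
side_by_side = np.concatenate([x_all, t_all], axis=1)  # left to right: (8, 3082)
print(x_all.shape, t_all.shape, side_by_side.shape)
```

The axis=1 call is included only to contrast the two directions; the loader itself only needs axis=0.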
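import numpy as np
import os
import pickle  # Python 3 name of the Python 2 cPickle module
import glob
import matplotlib.pyplot as plt

data_dir = "data"
data_dir_cifar10 = os.path.join(data_dir, "cifar-10-batches-py")
data_dir_cifar100 = os.path.join(data_dir, "cifar-100-python")

# The meta files are plain pickles: allow_pickle=True lets np.load unpickle them,
# and encoding="latin1" reads the Python 2 pickles with str keys.
class_names_cifar10 = np.load(os.path.join(data_dir_cifar10, "batches.meta"),
                              allow_pickle=True, encoding="latin1")
class_names_cifar100 = np.load(os.path.join(data_dir_cifar100, "meta"),
                               allow_pickle=True, encoding="latin1")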

def one_hot(x, n):
    """
    convert index representation to one-hot representation
    """
    x = np.array(x)
    assert x.ndim == 1
    return np.eye(n)[x]

def _load_batch_cifar10(filename, dtype='float64'):
    """
    load a batch in the CIFAR-10 format
    """
    path = os.path.join(data_dir_cifar10, filename)
    # the batch file is a pickle, so np.load needs allow_pickle=True
    batch = np.load(path, allow_pickle=True, encoding='latin1')
    data = batch['data'] / 255.0  # scale between [0, 1]
    labels = one_hot(batch['labels'], n=10)  # convert labels to one-hot representation
    return data.astype(dtype), labels.astype(dtype)

def _grayscale(a):
    # (N, 3072) -> (N, 3, 32, 32), average the 3 colour channels, flatten to (N, 1024)
    return a.reshape(a.shape[0], 3, 32, 32).mean(1).reshape(a.shape[0], -1)

def cifar10(dtype='float64', grayscale=True):
    # train
    x_train = []
    t_train = []
    for k in range(5):
        x, t = _load_batch_cifar10("data_batch_%d" % (k + 1), dtype=dtype)
        x_train.append(x)
        t_train.append(t)

    x_train = np.concatenate(x_train, axis=0)
    t_train = np.concatenate(t_train, axis=0)

    # test
    x_test, t_test = _load_batch_cifar10("test_batch", dtype=dtype)

    if grayscale:
        x_train = _grayscale(x_train)
        x_test = _grayscale(x_test)

    return x_train, t_train, x_test, t_test
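
def _load_batch_cifar100(filename, dtype='float64'):
    """
    load a batch in the CIFAR-100 format
    """
    path = os.path.join(data_dir_cifar100, filename)
    # the batch file is a pickle, so np.load needs allow_pickle=True
    batch = np.load(path, allow_pickle=True, encoding='latin1')
    data = batch['data'] / 255.0  # scale between [0, 1]
    labels = one_hot(batch['fine_labels'], n=100)  # one-hot over the 100 fine classes
    return data.astype(dtype), labels.astype(dtype)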

def cifar100(dtype='float64', grayscale=True):
    x_train, t_train = _load_batch_cifar100("train", dtype=dtype)
    x_test, t_test = _load_batch_cifar100("test", dtype=dtype)

    if grayscale:
        x_train = _grayscale(x_train)
        x_test = _grayscale(x_test)

    return x_train, t_train, x_test, t_test

Xtrain, Ytrain, Xtest, Ytest = cifar10()

################################################
# Display sample images
image = Xtrain[0].reshape(32, 32)
image1 = Xtrain[255].reshape(32, 32)

fig = plt.figure()
ax = fig.add_subplot(121)
plt.axis('off')
# the position of the 1 in the one-hot label is the class index
plt.title(class_names_cifar10['label_names'][list(Ytrain[0]).index(1)])
plt.imshow(image, cmap='gray')

ax = fig.add_subplot(122)
plt.title(class_names_cifar10['label_names'][list(Ytrain[255]).index(1)])
plt.imshow(image1, cmap='gray')
plt.axis('off')
plt.show()
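The listing above displays grayscale images. As a complement, the sketch below shows one way to turn a single 3072-element row back into a colour image. It relies on the same layout that _grayscale assumes, namely that the row reshapes to (3, 32, 32) with the channel axis first. The helper to_rgb_image and the synthetic row are hypothetical and used only for illustration.

```python
import numpy as np


def to_rgb_image(row):
    """Reshape a flat 3072-element row to a (32, 32, 3) height-width-channel image."""
    return np.asarray(row).reshape(3, 32, 32).transpose(1, 2, 0)


# Synthetic row: first channel fully on, the other two channels off.
fake_row = np.concatenate([np.ones(1024), np.zeros(2048)])

rgb = to_rgb_image(fake_row)
gray = rgb.mean(axis=2)  # same channel average that _grayscale computes
print(rgb.shape, gray.shape)
```

An array shaped like rgb can be passed to plt.imshow without a cmap argument. To obtain colour rows from the loader, call cifar10(grayscale=False).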
(1) CIFAR datasets: http://www.cs.toronto.edu/~kriz/cifar.html
(2) Dataset loading: https://github.com/benanne/theano-tutorial/blob/master/load.py