CIFAR-10 数据下载后,我们可以看看数据文件的结构里是什么内容:用下边的代码可以打印出一个批次文件的内容,并随机取出其中一张图片保存为图像文件。
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import os
import random
def unpickle(file):
    """Load a pickled CIFAR-10 batch file and return the stored dictionary.

    Args:
        file: Path of the pickled batch file to read.

    Returns:
        The unpickled dictionary. Python 2 ``str`` keys and values come
        back as ``bytes`` objects, which is why callers index the result
        with keys such as ``b'data'``.
    """
    import pickle

    # NOTE: pickle can execute arbitrary code while loading; only use this
    # on batch files obtained from a trusted source.
    with open(file, 'rb') as fo:
        # encoding='bytes' keeps 8-bit strings written by Python 2 as raw
        # bytes. The default ('ASCII') raises UnicodeDecodeError as soon as
        # a stored string contains a non-ASCII byte.
        batch = pickle.load(fo, encoding='bytes')
    return batch
def get_data(file):
    """Read one CIFAR-10 batch file and return images, one-hot labels and names.

    Args:
        file: Path of the pickled batch file; it is resolved to an absolute
            path before being opened.

    Returns:
        A tuple ``(X, Y, names)``:
            X: the transposed ``b'data'`` array cast to ``uint8``, so each
                column holds one sample.
            Y: float array of shape ``(10, N)`` where column ``i`` has a 1
                in the row given by label ``i`` and 0 elsewhere.
            names: array built from the ``b'filenames'`` entry.
    """
    abs_file = os.path.abspath(file)
    batch = unpickle(abs_file)
    # Dump the raw dictionary so the structure of the batch can be inspected.
    print(batch)
    X = np.asarray(batch[b'data'].T).astype("uint8")
    labels = np.asarray(batch[b'labels'], dtype=int)
    # Size the one-hot matrix from the labels actually present instead of
    # assuming exactly 10000 samples per batch.
    count = labels.shape[0]
    Y = np.zeros((10, count))
    # Integer-array indexing sets one entry per column in a single step.
    Y[labels, np.arange(count)] = 1
    names = np.asarray(batch[b'filenames'])
    return X, Y, names
def visualize_image(X,Y,names,id):
    """Draw one image from the batch and save the figure under its file name.

    Args:
        X: Image matrix; column ``id`` holds one flattened image of
            3 * 32 * 32 values stored channel by channel.
        Y: Not used by this function.
        names: Sequence of ASCII-encoded ``bytes`` file names, one per image.
        id: Column index of the image to draw.

    The figure is written into the current working directory, using the
    decoded file name as the output name.
    """
    pixels = X[:, id]
    # Stored layout is (channel, row, col); imshow expects (row, col, channel).
    img = pixels.reshape(3, 32, 32).transpose([1, 2, 0])
    # Decode once and reuse: passing the raw bytes to plt.title would render
    # the title as "b'...'" instead of the plain file name.
    filename = names[id].decode('ascii')
    plt.imshow(img)
    plt.title(filename)
    plt.savefig(os.path.join(os.path.abspath("./"), filename))
# Load the first training batch and save one randomly chosen image from it.
X, Y, names = get_data('./data_batch_1')
# randrange excludes its upper bound, so the index always falls inside
# 0 .. X.shape[1] - 1 (one column per image). random.randint(1, 10000) is
# inclusive on both ends: it could return 10000 (out of range) and never 0.
visualize_image(X, Y, names, random.randrange(X.shape[1]))
方法二:使用 MXNet 和 OpenCV(以下代码为 Python 2 写法)读取批次文件,并把前 10 张图片保存为 PNG:
import mxnet as mx
import numpy as np
import cPickle
import cv2
def extractImagesAndLabels(path, file):
    """Load the images and labels of one pickled CIFAR batch as MXNet arrays.

    Args:
        path: Directory prefix. It is concatenated directly with ``file``,
            so it must already end with a path separator.
        file: Name of the batch file inside ``path``.

    Returns:
        A tuple ``(imagearray, labelarray)`` of ``mx.nd.array`` objects:
        the ``'data'`` entry reshaped to ``(N, 3, 32, 32)`` and the
        ``'labels'`` entry.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager closes the handle even if unpickling fails.
    with open(path + file, 'rb') as f:
        batch = cPickle.load(f)
    # -1 lets NumPy infer the image count, so batches that do not hold
    # exactly 10000 images are reshaped correctly as well.
    images = np.reshape(batch['data'], (-1, 3, 32, 32))
    labels = batch['labels']
    imagearray = mx.nd.array(images)
    labelarray = mx.nd.array(labels)
    return imagearray, labelarray
def extractCategories(path, file):
    """Return the ``'label_names'`` entry of a pickled CIFAR metadata file.

    Args:
        path: Directory prefix. It is concatenated directly with ``file``,
            so it must already end with a path separator.
        file: Name of the metadata file inside ``path``.

    Returns:
        The value stored under the ``'label_names'`` key.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager closes the handle even if unpickling fails.
    with open(path + file, 'rb') as f:
        meta = cPickle.load(f)
    return meta['label_names']
def saveCifarImage(array, path, file):
    """Write one image array to ``path + file + ".png"`` with OpenCV.

    Args:
        array: Image in channel-first layout (3x32x32) exposing
            ``asnumpy()``; channels are in RGB order.
        path: Directory prefix, concatenated directly with ``file``.
        file: Output file name without the ``.png`` extension.

    Returns:
        Whatever ``cv2.imwrite`` returns for the write.
    """
    # cv2 wants height x width x channel, so move the channel axis last.
    hwc_rgb = array.asnumpy().transpose(1, 2, 0)
    # cv2 stores colour channels as BGR rather than RGB.
    hwc_bgr = cv2.cvtColor(hwc_rgb, cv2.COLOR_RGB2BGR)
    target = path + file + ".png"
    return cv2.imwrite(target, hwc_bgr)
# Load the test batch and report the shapes of the image and label arrays.
# A single-argument print(...) prints the same text as the print statement.
imgarray, lblarray = extractImagesAndLabels("./", "test_batch")
print(imgarray.shape)
print(lblarray.shape)

categories = extractCategories("./", "batches.meta")

# Save the first ten images as image0.png .. image9.png and collect the
# category name that belongs to each of them.
cats = []
for i in range(10):
    saveCifarImage(imgarray[i], "./", "image" + str(i))
    # The label comes back as a one-element array; take it as a plain int.
    category = int(lblarray[i].asnumpy()[0])
    cats.append(categories[category])
print(cats)