CIFAR-10 数据的显示及 OpenCV 保存

方法一：下载 CIFAR-10 数据后，可以先看看数据批次文件里包含哪些内容。下面的代码读取 `data_batch_1`，打印其中的字典，然后随机选取一张图片，用 matplotlib 显示并保存为图片文件。

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import os
import random

def unpickle(file):
    """Load one pickled CIFAR-10 batch file and return its dictionary.

    ``encoding='bytes'`` is passed so that Python 2 ``str`` objects stored in
    the pickle (including the dictionary keys) are returned as ``bytes``
    instead of being decoded as ASCII. This matches how the batches are
    indexed in this script, e.g. ``batch[b'data']`` and ``batch[b'labels']``.

    Security note: ``pickle.load`` can execute arbitrary code. Only use this on
    trusted files, such as the official CIFAR-10 download.
    """
    import pickle
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    return batch

def get_data(file, num_classes=10):
    """Load a CIFAR-10 batch and return (images, one-hot labels, filenames).

    Args:
        file: path to a batch file (e.g. ``./data_batch_1``); it is resolved
            to an absolute path before loading.
        num_classes: number of rows in the one-hot label matrix (10 for
            CIFAR-10).

    Returns:
        X: ``uint8`` array holding one flattened image per *column* (the
            batch's ``data`` array transposed).
        Y: float array of shape ``(num_classes, n)`` with ``Y[label, i] == 1``
            for sample ``i``, where ``n`` is the number of labels in the batch.
        names: array of the per-image file names as stored in the batch.

    The whole batch dictionary is printed first so its structure can be
    inspected.
    """
    batch = unpickle(os.path.abspath(file))

    # Deliberate: dump the raw dictionary to see what the batch contains.
    print(batch)

    X = np.asarray(batch[b'data'].T).astype("uint8")
    labels = np.asarray(batch[b'labels'])
    # Take the sample count from the data instead of hard-coding 10000, and
    # set every one-hot entry in one vectorised assignment.
    n = labels.shape[0]
    Y = np.zeros((num_classes, n))
    Y[labels, np.arange(n)] = 1
    names = np.asarray(batch[b'filenames'])
    return X, Y, names

def visualize_image(X, Y, names, id):
    """Draw image number ``id`` of a batch with matplotlib and save the figure.

    Args:
        X: image matrix with one flattened image per column; each column is
            reshaped to three 32x32 channel planes.
        Y: labels; accepted for call compatibility but not used here.
        names: per-image file names as ASCII ``bytes``. The decoded name is
            used both as the plot title and as the output file name.
        id: zero-based column index of the image to draw.

    The figure is saved in the current working directory.
    """
    rgb = X[:, id]
    # Planar (C, H, W) -> interleaved (H, W, C), the layout imshow expects.
    img = rgb.reshape(3, 32, 32).transpose([1, 2, 0])

    # Decode once: a bytes title would be rendered literally as "b'...'".
    name = names[id].decode('ascii')

    plt.imshow(img)
    plt.title(name)
    plt.savefig(os.path.join(os.path.abspath("./"), name))


X, Y, names = get_data('./data_batch_1')
# random.randint is inclusive at both ends, so randint(1, 10000) could pick
# the out-of-range column 10000 and could never pick column 0. randrange
# draws uniformly from the valid column indices 0 .. X.shape[1] - 1.
visualize_image(X, Y, names, random.randrange(X.shape[1]))

方法二（使用 MXNet 与 OpenCV；注意这段代码使用了 Python 2 的 `cPickle` 模块）：

import mxnet as mx
import numpy as np
import cPickle
import cv2

def extractImagesAndLabels(path, file):
    """Load a CIFAR-10 batch and return its images and labels as MXNet arrays.

    ``path`` and ``file`` are concatenated as-is, so ``path`` must end with a
    separator (e.g. ``"./"``). Only load trusted files: unpickling can
    execute arbitrary code.

    Returns:
        imagearray: ``mx.nd.array`` of shape (N, 3, 32, 32).
        labelarray: ``mx.nd.array`` holding the batch's labels.
    """
    # ``with`` guarantees the file is closed, even if unpickling fails.
    with open(path + file, 'rb') as f:
        batch = cPickle.load(f)
    # -1 lets numpy infer N, so batches of any size work (not only 10000).
    images = np.reshape(batch['data'], (-1, 3, 32, 32))
    labels = batch['labels']
    return mx.nd.array(images), mx.nd.array(labels)

def extractCategories(path, file):
    """Return the ``label_names`` entry of a pickled CIFAR-10 meta file.

    ``path`` and ``file`` are concatenated as-is, so ``path`` must end with a
    separator. Only load trusted files: unpickling can execute arbitrary code.
    """
    # ``with`` closes the file handle, which the bare open() never did.
    with open(path + file, 'rb') as f:
        return cPickle.load(f)['label_names']

def saveCifarImage(array, path, file):
    """Write one CIFAR image to ``path + file + ".png"`` with OpenCV.

    Args:
        array: MXNet NDArray of shape 3x32x32 in RGB channel order.
        path, file: concatenated as-is to form the output name (".png" is
            appended).

    Returns:
        The result of ``cv2.imwrite`` (True on success).
    """
    # The NDArray is (C, H, W); cv2 needs (H, W, C).
    pixels = array.asnumpy().transpose(1, 2, 0)
    # mx.nd.array stores numpy input as float32 by default. Convert to a
    # contiguous uint8 image explicitly (CIFAR pixels are 0..255) rather than
    # relying on imwrite's implicit depth conversion for PNG.
    pixels = np.ascontiguousarray(pixels, dtype=np.uint8)
    # The data is RGB; cv2 expects BGR.
    pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
    return cv2.imwrite(path + file + ".png", pixels)

imgarray, lblarray = extractImagesAndLabels("./", "test_batch")
print(imgarray.shape)
print(lblarray.shape)

categories = extractCategories("./", "batches.meta")

# Save the first ten test images as image0.png .. image9.png and collect the
# class name of each one in the same order.
cats = []
for idx in range(10):
    saveCifarImage(imgarray[idx], "./", "image%d" % idx)
    label = int(lblarray[idx].asnumpy()[0])
    cats.append(categories[label])
print(cats)

 

你可能感兴趣的:(CIFAR-10 数据的显示 及 opencv保存)