python读取cifar10并显示

最近在做基于哈希的图像检索,我准备使用cifar10数据集,用来训练哈希函数并测试。我想用python把cifar10里的数据读出,斌且显示出出图片。下面是我的代码。

有在做基于哈希的近邻搜索的朋友,可以交流交流。

import numpy as np
import os
from matplotlib import pyplot as plt
import pickle

data_dir = "C:\\Users\\25806\\pyenv\\testenv\\hash\\hashsuanfa"
data_dir_cifar10 = os.path.join(data_dir,"cifar-10-batches-py")
class_name_cifar10 = np.load(os.path.join(data_dir_cifar10,"batches.meta"))
def load_batch_cifar10(filename,dtype="float 64"):
    path = os.path.join(data_dir_cifar10,filename)#链接字符串,合成文件路径
    fi = open(path, 'rb')  # 打开文件
    batch = pickle.load(fi, encoding="bytes")  # 读入数据
    fi.close()
    data = batch[b'data']/255.0
    labels = batch[b'labels']#每一个数据的标签
    return data,labels#返回标签矩阵
def load_cifar10():
    x_train = []#存放训练数据,最终是50000*3072的矩阵
    y_train = []
    for i in range(5):#读取五个文件
        x,t = load_batch_cifar10("data_batch_%d"%(i+1))
        x_train.append(x)
        y_train.append(t)
    x_test ,y_test= load_batch_cifar10("test_batch")#读取测试文件
    x_train = np.concatenate(x_train,axis=0)#将五个文件的矩阵合成一个
    y_train = np.concatenate(y_train, axis=0)

    x_train = x_train.reshape(x_train.shape[0],3,32,32)
    x_test = x_test.reshape(x_test.shape[0],3,32,32)
    return x_train,y_train,x_test,y_test
Xtrain,Ytrain,Xtest,Ytest = load_cifar10()
imlist = []
for i in range(24): #显示24张图片
    red = Xtrain[i][0].reshape(1024,1)
    green = Xtrain[i][1].reshape(1024,1)
    blue = Xtrain[i][2].reshape(1024,1)
    pic = np.hstack((red,green,blue))
    pic_grab = pic.reshape(32,32,3)#合成一个三维矩阵,每一个点包含红绿蓝三种颜色
    imlist.append(pic_grab)
fig = plt.figure()
for j in range(1,25):
    ax = fig.add_subplot(4,6,j)#这三个参数是,图片行数,列数,编号
    plt.title(class_name_cifar10['label_names'][Ytrain[j-1]])
    plt.axis('off')#不显示坐标值
    plt.imshow(imlist[j-1])#显示图片
plt.subplots_adjust(wspace=0,hspace=0)
plt.show()

你可能感兴趣的:(cifar10)