cifar10数据集的读取

cifar10数据集----加载数据

import pickle

def load_cifar10_batch(cifar10_dataset_folder_path,batch_id):

    with open(cifar10_dataset_folder_path + '/data_batch_' + str(batch_id),mode='rb') as file:
        batch = pickle.load(file, encoding = 'latin1')
        
    # features and labels
    features = batch['data'].reshape((len(batch['data']),3,32,32)).transpose(0,2,3,1)
    labels = batch['labels']
    
    return  features, labels        
# 加载所有数据
cifar10_path = '/root/zhj/python3/code/data/cifar-10-batches-py'
# 一共有5个batch的训练数据
x_train, y_train = load_cifar10_batch(cifar10_path, 1)
for i in range(2,6):
    features,labels = load_cifar10_batch(cifar10_path, i)
    x_train, y_train = np.concatenate([x_train, features]),np.concatenate([y_train, labels])

# 加载测试数据
with open(cifar10_path + '/test_batch', mode = 'rb') as file:
    batch = pickle.load(file, encoding='latin1')
    x_test = batch['data'].reshape((len(batchtch['data']),3,32,32)).transpose(0,2,3,1)
    y_test = batch['labels']
---------------------------------------------------------------------------

UnpicklingError                           Traceback (most recent call last)

 in ()
      1 # 加载测试数据
      2 with open(cifar10_path + '/test_batch', mode = 'rb') as file:
----> 3     batch = pickle.load(file, encoding='latin1')
      4     x_test = batch['data'].reshape((len(batchtch['data']),3,32,32)).transpose(0,2,3,1)
      5     y_test = batch['labels']


UnpicklingError: invalid load key, '\xfe'.

问题:数据导入是以bytes形式导入的,而在load时出现了无效的字符’\xfe’

解决办法:测试数据上传至云服务器时出现错误,重新上传一次

# 加载测试数据
with open(cifar10_path + '/test_batch', mode = 'rb') as file:
    batch = pickle.load(file, encoding='latin1')
    x_test = batch['data'].reshape((len(batch['data']),3,32,32)).transpose(0,2,3,1)
    y_test = batch['labels']

显示图片

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
#%matplotlib inline


fig,axes = plt.subplots(nrows=3,ncols=20, sharex=True, sharey=True, figsize=(80,12))
# subplot(numRows, numCols, plotNum)  numRows, numCols表示绘图区域被分成 numRows 行和 numCols 列
# 然后按照从左到右,从上到下的顺序对每个子区域进行编号,左上的子区域的编号为1
# plotNum 参数指定创建的 Axes 对象所在的区域
# 如果 numRows, numCols 和 plotNum 这三个数都小于 10 的话, 可以把它们缩写为一个整数, 例如 subplot(323) 和 subplot(3,2,3) 是相同的

imgs = x_train[:60]

# zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的列表。
# 如果各个迭代器的元素个数不一致,则返回列表长度与最短的对象相同,利用 * 号操作符,可以将元组解压为列表。

for imgs, row in zip([imgs[:20],imgs[20:40],imgs[40:60]],axes):
    for img, ax in zip(imgs, row):
        ax.imshow(img)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)   
        # plt.axis('off') #关闭xy坐标轴
        # frame.axes.get_yaxis().set_visible(False)  #不显示y轴        
        # frame.axes.get_xaxis().set_visible(False) #不显示x轴

fig.tight_layout(pad=0.1)
# 在 matplotlib 中,轴域(包括子图)的位置以标准化图形坐标指定。 可能发生的是,你的轴标签或标题(有时甚至是刻度标签)会超出图形区域,因此被截断
# 为了避免它,轴域的位置需要调整。对于子图,这可以通过调整子图参数(移动轴域的一条边来给刻度标签腾地方)。Matplotlib v1.1 引入了一个新的命令tight_layout(),自动为你解决这个问题。
# 当拥有多个子图时,会经常看到不同轴域的标签叠在一起
# tight_layout()也会调整子图之间的间隔来减少堆叠

plt.show() 
# 一定要加这个图片才会都显示出来,仅仅在大图里面显示子图示无法显示的。要再显示子图之后调用一次大图(fig)的显示

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-N0lXdjoV-1586522327406)(output_7_0.png)]

你可能感兴趣的:(机器学习)