读取cifar数据集

读取cifar数据集

在学习深度学习的过程中必须要用到数据集对模型进行训练,本文主要介绍如何读取cifar数据集。
cifar数据集的下载地址为:CIFAR官网,下载速度较慢
百度云下载地址为:百度云下载地址
下载好后解开压缩包。
1.用python3来读取cifar文件

   import pickle
    def unpickle(self,f):
        fo = open(f, 'rb')
        d = pickle.load(fo,encoding='latin1')
        fo.close()
        return d

此时读取的文件是一个字典,你可以查看关键字

print(d.key())

然后,根据关键字将有用的信息提取出来

data = d['data']
labels = d['labels']

读取后的data数据并不是一张张32323的图片,需要进行转换,如下。

new = data.reshape(10000,3,32,32)
#将[10000][3][32][32]转为[10000][32][32][3]
imgs = new.transpose((0,2,3,1))

此时可以作为机器学习的输入数据了。来看看图片是什么样的吧!

import matplotlib.pyplot as plt
plt.imshow(pic_test[1000])
plt.legend()
plt.show()

附上一个比较完整的代码(在python3上运行)

# -*- coding: utf-8 -*-
"""
Created on Wed Sep  4 20:21:56 2019

@author: ASUS
"""

import pickle
import numpy as np
import os
 
class Cifar10DataReader():
    def __init__(self,cifar_folder,onehot=True):
        self.cifar_folder=cifar_folder
        self.onehot=onehot
        self.data_index=1
        self.read_next=True
        self.data_label_train=None
        self.data_label_test=None
        self.batch_index=0
        
    def unpickle(self,f):
        fo = open(f, 'rb')
        d = pickle.load(fo,encoding='latin1')
        fo.close()
        return d
    
    def next_train_data(self,batch_size=100):
        assert 10000%batch_size==0,"10000%batch_size!=0"
        rdata=None
        rlabel=None
        if self.read_next:
            f=os.path.join(self.cifar_folder,"data_batch_%s"%(self.data_index))
            #print 'read: %s'%f
            dic_train=self.unpickle(f)
            self.data_label_train=list(zip(dic_train['data'],dic_train['labels']))#label 0~9
            np.random.shuffle(self.data_label_train)
            
            self.read_next=False
            if self.data_index==5:
                self.data_index=1
            else: 
                self.data_index+=1
        
        if self.batch_index

你可能感兴趣的:(深度学习)