用python读取cifar-10与cifar-100图像数据

很多公开的机器学习数据集都需要自己编写代码来读取。能自己写读取代码当然是机器学习应用的基本功,不过这里把代码贴出来,方便大家开发,避免重复造轮子。

关于cifar数据集的介绍,请点击这里。由于官方地址下载比较慢,也可以通过CSDN提供的地址下载cifar-10:cifar-10 csdn地址

下载后将其解压,如路径为: /xxx/cifar-10-batches-py/

代码很简单没有写注释,读取代码如下:

import cPickle
import numpy as np
import os

class Cifar10DataReader():
    def __init__(self,cifar_folder,onehot=True):
        self.cifar_folder=cifar_folder
        self.onehot=onehot
        self.data_index=1
        self.read_next=True
        self.data_label_train=None
        self.data_label_test=None
        self.batch_index=0
        
    def unpickle(self,f):
        fo = open(f, 'rb')
        d = cPickle.load(fo)
        fo.close()
        return d
    
    def next_train_data(self,batch_size=100):
        assert 10000%batch_size==0,"10000%batch_size!=0"
        rdata=None
        rlabel=None
        if self.read_next:
            f=os.path.join(self.cifar_folder,"data_batch_%s"%(self.data_index))
            print 'read: %s'%f
            dic_train=self.unpickle(f)
            self.data_label_train=zip(dic_train['data'],dic_train['labels'])#label 0~9
            np.random.shuffle(self.data_label_train)
            
            self.read_next=False
            if self.data_index==5:
                self.data_index=1
            else: 
                self.data_index+=1
        
        if self.batch_index

cifar-100的数据读取代码如下(测试集的读取方式与cifar-10相同,这里不再重复。数据中还包含coarse_labels,即0~19共20个大类别的标签,需要的话可以自行添加):



import cPickle
import numpy as np
import os

class Cifar100DataReader():
    def __init__(self,cifar_folder,onehot=True):
        self.cifar_folder=cifar_folder
        self.onehot=onehot
        self.data_label_train=None
        self.data_label_test=None
        self.batch_index=0
        f=os.path.join(self.cifar_folder,"train")
        print 'read: %s'%f
        dic_train=unpickle(f)
        self.data_label_train=zip(dic_train['data'],dic_train['fine_labels'])#label 0~99
        np.random.shuffle(self.data_label_train)
        
        
    def next_train_data(self,batch_size=100):
        """
        cifar100 data content:
            {
            "coarse_labels":[0,...,19],#0~19 super category
            "filenames":["volcano_s_000012.png",...],
            "batch_label":"",
            "fine_labels":[0,1...99]#0~99 category
            }
        return list of numpy arrays [na,...,na] with specific batch_size
                na: N dimensional numpy array 
        """
        
        if self.batch_index




你可能感兴趣的:(machine,learning,computer,language)