用python读取cifar-10与cifar-100图像数据


很多机器学习公开数据集都需要自己编写代码来读取。自己写读取代码当然是机器学习应用的基本功,这里给出现成的代码,方便大家开发,避免重复造轮子。

关于CIFAR数据集的介绍,参见官方页面:https://www.cs.toronto.edu/~kriz/cifar.html 。由于官方下载速度较慢,也可以通过CSDN上的cifar-10下载资源获取。

下载后解压,例如解压到路径:/xxx/cifar-10-batches-py/

代码比较简单,读取代码如下:

[python]
  1. import cPickle  
  2. import numpy as np  
  3. import os  
  4.   
  5. class Cifar10DataReader():  
  6.     def __init__(self,cifar_folder,onehot=True):  
  7.         self.cifar_folder=cifar_folder  
  8.         self.onehot=onehot  
  9.         self.data_index=1  
  10.         self.read_next=True  
  11.         self.data_label_train=None  
  12.         self.data_label_test=None  
  13.         self.batch_index=0  
  14.           
  15.     def unpickle(self,f):  
  16.         fo = open(f, 'rb')  
  17.         d = cPickle.load(fo)  
  18.         fo.close()  
  19.         return d  
  20.       
  21.     def next_train_data(self,batch_size=100):  
  22.         assert 10000%batch_size==0,"10000%batch_size!=0"  
  23.         rdata=None  
  24.         rlabel=None  
  25.         if self.read_next:  
  26.             f=os.path.join(self.cifar_folder,"data_batch_%s"%(self.data_index))  
  27.             print 'read: %s'%f  
  28.             dic_train=self.unpickle(f)  
  29.             self.data_label_train=zip(dic_train['data'],dic_train['labels'])#label 0~9  
  30.             np.random.shuffle(self.data_label_train)  
  31.               
  32.             self.read_next=False  
  33.             if self.data_index==5:  
  34.                 self.data_index=1  
  35.             else:   
  36.                 self.data_index+=1  
  37.           
  38.         if self.batch_indexself.data_label_train)//batch_size:  
  39.             #print self.batch_index  
  40.             datum=self.data_label_train[self.batch_index*batch_size:(self.batch_index+1)*batch_size]  
  41.             self.batch_index+=1  
  42.             rdata,rlabel=self._decode(datum,self.onehot)  
  43.         else:  
  44.             self.batch_index=0  
  45.             self.read_next=True  
  46.             return self.next_train_data(batch_size=batch_size)  
  47.               
  48.         return rdata,rlabel  
  49.       
  50.     def _decode(self,datum,onehot):  
  51.         rdata=list();rlabel=list()  
  52.         if onehot:  
  53.             for d,l in datum:  
  54.                 rdata.append(np.reshape(np.reshape(d,[3,1024]).T,[32,32,3]))  
  55.                 hot=np.zeros(10)  
  56.                 hot[int(l)]=1  
  57.                 rlabel.append(hot)  
  58.         else:  
  59.             for d,l in datum:  
  60.                 rdata.append(np.reshape(np.reshape(d,[3,1024]).T,[32,32,3]))  
  61.                 rlabel.append(int(l))  
  62.         return rdata,rlabel  
  63.               
  64.     def next_test_data(self,batch_size=100):  
  65.         if self.data_label_test is None:  
  66.             f=os.path.join(self.cifar_folder,"test_batch")  
  67.             print 'read: %s'%f  
  68.             dic_test=self.unpickle(f)  
  69.             data=dic_test['data']  
  70.             labels=dic_test['labels']#0~9  
  71.             self.data_label_test=zip(data,labels)  
  72.           
  73.         np.random.shuffle(self.data_label_test)  
  74.         datum=self.data_label_test[0:batch_size]  
  75.           
  76.         return self._decode(datum,self.onehot)  
  77.   
  78. if __name__=="__main__":  
  79.     dr=Cifar10DataReader(cifar_folder="/xxx/cifar-10-batches-py/")  
  80.     import matplotlib.pyplot as plt  
  81.     d,l=dr.next_test_data()  
  82.     print np.shape(d),np.shape(l)  
  83.     plt.imshow(d[0])  
  84.     plt.show()  
  85.     for i in xrange(600):  
  86.         d,l=dr.next_train_data(batch_size=100)  
  87.         print np.shape(d),np.shape(l)  
  88.    

CIFAR-100的数据读取代码如下(测试代码与CIFAR-10相同,这里不再重复)。CIFAR-100数据中还有coarse_labels,即0~19的大类别标签,需要的话可以自行添加。


[python]
  1. import cPickle  
  2. import numpy as np  
  3. import os  
  4.   
  5. class Cifar100DataReader():  
  6.     def __init__(self,cifar_folder,onehot=True):  
  7.         self.cifar_folder=cifar_folder  
  8.         self.onehot=onehot  
  9.         self.data_label_train=None  
  10.         self.data_label_test=None  
  11.         self.batch_index=0  
  12.         f=os.path.join(self.cifar_folder,"train")  
  13.         print 'read: %s'%f  
  14.         dic_train=unpickle(f)  
  15.         self.data_label_train=zip(dic_train['data'],dic_train['fine_labels'])#label 0~99  
  16.         np.random.shuffle(self.data_label_train)  
  17.           
  18.           
  19.     def next_train_data(self,batch_size=100):  
  20.         """ 
  21.         cifar100 data content: 
  22.             { 
  23.             "coarse_labels":[0,...,19],#0~19 super category 
  24.             "filenames":["volcano_s_000012.png",...], 
  25.             "batch_label":"", 
  26.             "fine_labels":[0,1...99]#0~99 category 
  27.             } 
  28.         return list of numpy arrays [na,...,na] with specific batch_size 
  29.                 na: N dimensional numpy array  
  30.         """  
  31.           
  32.         if self.batch_indexself.data_label_train)//batch_size:  
  33.             #print self.batch_index  
  34.             datum=self.data_label_train[self.batch_index*batch_size:(self.batch_index+1)*batch_size]  
  35.             self.batch_index+=1  
  36.             return self._decode(datum,self.onehot)  
  37.         else:  
  38.             self.batch_index=0  
  39.             np.random.shuffle(self.data_label_train)  
  40.             datum=self.data_label_train[self.batch_index*batch_size:(self.batch_index+1)*batch_size]  
  41.             self.batch_index+=1  
  42.             return self._decode(datum,self.onehot)  
  43.               
  44.       
  45.     def _decode(self,datum,onehot):  
  46.         rdata=list();rlabel=list()  
  47.         if onehot:  
  48.             for d,l in datum:  
  49.                 rdata.append(np.reshape(np.reshape(d,[3,1024]).T,[32,32,3]))  
  50.                 hot=np.zeros(100)  
  51.                 hot[int(l)]=1  
  52.                 rlabel.append(hot)  
  53.         else:  
  54.             for d,l in datum:  
  55.                 rdata.append(np.reshape(np.reshape(d,[3,1024]).T,[32,32,3]))  
  56.                 rlabel.append(int(l))  
  57.         return rdata,rlabel  
  58.               
  59.     def next_test_data(self,batch_size=100):  
  60.         ''''' 
  61.         return list of numpy arrays [na,...,na] with specific batch_size 
  62.                 na: N dimensional numpy array  
  63.         '''  
  64.         if self.data_label_test is None:  
  65.             f=os.path.join(self.cifar_folder,"test")  
  66.             print 'read: %s'%f  
  67.             dic_test=unpickle(f)  
  68.             data=dic_test['data']  
  69.             #print len(dic_test["coarse_labels"])  
  70.             #print len(dic_test["filenames"])  
  71.             labels=dic_test['fine_labels']#0~99  
  72.             self.data_label_test=zip(data,labels)  
  73.               
  74.         np.random.shuffle(self.data_label_test)  
  75.         datum=self.data_label_test[0:batch_size]  
  76.           
  77.         return self._decode(datum,self.onehot)  



你可能感兴趣的:【Python相关】