Just reading the data took me, a beginner, quite a while to figure out, so here is a summary.
First of all: if you download the data from the web, be sure to check which version you are getting. I took a big hit on this. I had been using a cifar10 module that auto-downloaded the dataset through someone else's package, and what it fetched were the .bin files, i.e. the binary version intended for C, so loading them with Python's pickle package kept failing no matter what I tried.
The dataset download page is: http://www.cs.toronto.edu/~kriz/cifar.html
Be sure to download the right version!!
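If you are unsure which version you grabbed: the Python version unpacks to a cifar-10-batches-py folder containing data_batch_1 through data_batch_5, test_batch, and batches.meta, while the binary version contains .bin files. A quick check (the folder path below is just an assumption about where you extracted the archive):

import os

folder = 'cifar-10-batches-py'  # wherever you extracted the archive
expected = ['data_batch_%d' % i for i in range(1, 6)] + ['test_batch', 'batches.meta']
missing = [f for f in expected if not os.path.exists(os.path.join(folder, f))]
print('missing files:', missing)  # [] means you have the Python version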
With the right files, reading is straightforward. This is the Python 3 loading routine given on the website:
import pickle

def unpick(f):
    with open(f, 'rb') as fo:
        dic = pickle.load(fo, encoding='bytes')
    return dic
What you get back is a dict. Note that its keys are bytes, not str: to read the image data, for example, you index with dic[b'data']; the b prefix is what makes a string literal a bytes object.
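For example, a quick sanity check on one training file with the unpick function above (the relative path is just my local layout):

batch = unpick('cifar-10-batches-py/data_batch_1')
print(batch.keys())          # dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
print(batch[b'data'].shape)  # (10000, 3072): 10000 images, each 32*32*3 = 3072 uint8 values
print(batch[b'labels'][:5])  # plain ints in 0~9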
I also referred to this post: https://blog.csdn.net/u010165147/article/details/54176612
The code there has a few errors; I fixed them and tested the result myself. Here is my revised version:
import pickle
import numpy as np
import os

class Cifar10DataReader():
    def __init__(self, cifar_folder, onehot=True):
        self.cifar_folder = cifar_folder
        self.onehot = onehot
        self.data_index = 1    # which data_batch file to read next
        self.read_next = True  # whether a new file must be loaded
        self.data_label_train = None
        self.data_label_test = None
        self.batch_index = 0

    def unpickle(self, f):
        fo = open(f, 'rb')
        d = pickle.load(fo, encoding='bytes')
        fo.close()
        return d

    def next_train_data(self, batch_size=100):
        assert 10000 % batch_size == 0, "batch_size must divide 10000"
        rdata = None
        rlabel = None
        if self.read_next:
            f = os.path.join(self.cifar_folder, "data_batch_%s" % self.data_index)
            print('read: %s' % f)
            dic_train = self.unpickle(f)
            self.data_label_train = list(zip(dic_train[b'data'], dic_train[b'labels']))  # label 0~9
            np.random.shuffle(self.data_label_train)
            self.read_next = False
            if self.data_index == 5:
                self.data_index = 1
            else:
                self.data_index += 1
        if self.batch_index < len(self.data_label_train) // batch_size:
            datum = self.data_label_train[self.batch_index * batch_size:(self.batch_index + 1) * batch_size]
            self.batch_index += 1
            rdata, rlabel = self._decode(datum, self.onehot)
        else:  # this file is exhausted; reset and recurse into the next file
            self.batch_index = 0
            self.read_next = True
            return self.next_train_data(batch_size=batch_size)
        return rdata, rlabel

    def _decode(self, datum, onehot):
        rdata = list()
        rlabel = list()
        if onehot:
            for d, l in datum:
                # rows are channel-first (1024 R, 1024 G, 1024 B); convert to a 32x32x3 tensor
                rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
                hot = np.zeros(10)
                hot[int(l)] = 1
                rlabel.append(hot)
        else:
            for d, l in datum:
                rdata.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
                rlabel.append(int(l))
        return rdata, rlabel

    def next_test_data(self, batch_size=100):
        if self.data_label_test is None:
            f = os.path.join(self.cifar_folder, "test_batch")
            print('read: %s' % f)
            dic_test = self.unpickle(f)
            data = dic_test[b'data']
            labels = dic_test[b'labels']  # 0~9
            self.data_label_test = list(zip(data, labels))
            np.random.shuffle(self.data_label_test)
        datum = self.data_label_test[0:batch_size]
        return self._decode(datum, self.onehot)

if __name__ == "__main__":
    dr = Cifar10DataReader(cifar_folder=r"E:\Tensorlow\Project\深度学习练习\cifar-10-batches-py")
    import matplotlib.pyplot as plt
    d, l = dr.next_test_data()
    print(np.shape(d), np.shape(l))
    plt.imshow(d[2])
    plt.show()
    # for i in range(600):
    #     d, l = dr.next_train_data(batch_size=100)
    #     print(np.shape(d), np.shape(l))
The resulting image:
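As a side note, if you want the human-readable class name instead of the bare 0~9 index, the batches.meta file that ships in the same folder maps label indices to names. A minimal sketch continuing the __main__ block above (it reuses dr, l, os, and np, and assumes the default onehot=True):

meta = dr.unpickle(os.path.join(dr.cifar_folder, 'batches.meta'))
names = [n.decode() for n in meta[b'label_names']]  # ['airplane', 'automobile', ..., 'truck']
print(names[int(np.argmax(l[2]))])  # class name of the image shown above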
------------------------------------------------------------------------------------------------------------------------------------------------------------
Beep beep! That was the code I tested two days ago. Yesterday I rewrote it from scratch and changed a few things, because I felt parts of the earlier code weren't well written and had poor reuse: for example, data files kept getting reloaded across training batches, which seemed like it would really hurt efficiency. So I made some improvements and added extensibility to the class: you can now call the data-reading functions directly to get the raw vector-form data, rather than only getting one batch of data tensors. I also added comments (my skill level isn't high, so the comments may only make sense to me; please bear with me). The code is below.
import os
import random
import pickle
import numpy as np

class Cifar10DataReader():
    def __init__(self, cifar_file, one_hot=False, file_number=1):
        self.batch_index = 0            # index of the current batch within the file
        self.file_number = file_number  # which data_batch file is loaded
        self.cifar_file = cifar_file    # directory holding the dataset
        self.one_hot = one_hot
        self.train_data = self.read_train_file()  # training data of one file: a list of 10000 (data, label) pairs
        self.test_data = self.read_test_data()    # the 10000 test examples

    # Load one pickled file and return its dict
    def unpickle(self, file):
        with open(file, 'rb') as fo:
            try:
                dicts = pickle.load(fo, encoding='bytes')
            except Exception as e:
                print('load error', e)
                raise
        return dicts

    # Read one training file; return its examples as a shuffled list
    def read_train_file(self, files=''):
        if files:
            files = os.path.join(self.cifar_file, files)
        else:
            files = os.path.join(self.cifar_file, 'data_batch_%d' % self.file_number)
        dict_train = self.unpickle(files)
        train_data = list(zip(dict_train[b'data'], dict_train[b'labels']))  # pair each image with its label
        np.random.shuffle(train_data)
        print('loaded training data: data_batch_%d' % self.file_number)
        return train_data

    # Read the test data
    def read_test_data(self):
        files = os.path.join(self.cifar_file, 'test_batch')
        dict_test = self.unpickle(files)
        test_data = list(zip(dict_test[b'data'], dict_test[b'labels']))  # pair each image with its label
        print('loaded test data')
        return test_data

    # Decode raw rows into tensors; return data and labels separately
    def encodedata(self, datum):
        rdatas = list()
        rlabels = list()
        for d, l in datum:
            # rows are channel-first (1024 R, 1024 G, 1024 B); convert to 32x32x3
            rdatas.append(np.reshape(np.reshape(d, [3, 1024]).T, [32, 32, 3]))
            if self.one_hot:
                hot = np.zeros(10)
                hot[int(l)] = 1
                rlabels.append(hot)
            else:
                rlabels.append(l)
        return rdatas, rlabels

    # Return batch_size images and labels
    def next_train_data(self, batch_size=100):
        assert 10000 % batch_size == 0, 'error: batch_size must divide 10000!'  # each file holds 10000 examples
        # take one batch_size slice of the current file
        if self.batch_index < len(self.train_data) // batch_size:  # still within the current file?
            datum = self.train_data[self.batch_index * batch_size:(self.batch_index + 1) * batch_size]
            datas, labels = self.encodedata(datum)
            self.batch_index += 1
        else:  # past the end of this file: load the next one
            self.batch_index = 0
            if self.file_number == 5:
                self.file_number = 1
            else:
                self.file_number += 1
            self.train_data = self.read_train_file()  # keep the freshly loaded file
            return self.next_train_data(batch_size=batch_size)
        return datas, labels

    # Randomly sample batch_size test examples
    def next_test_data(self, batch_size=100):
        datum = random.sample(self.test_data, batch_size)  # random draw without replacement
        datas, labels = self.encodedata(datum)
        return datas, labels

if __name__ == '__main__':
    import matplotlib.pyplot as plt
    Cifar10 = Cifar10DataReader(r'E:\Tensorlow\Project\深度学习练习\cifar-10-batches-py', one_hot=True)
    d, l = Cifar10.next_train_data()
    print(len(d))
    print(d)
    plt.imshow(d[0])
    plt.show()
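As a quick usage sketch for the new reader: calling next_train_data 50000 // batch_size times walks one full epoch over all five training files, since the reader swaps in the next data_batch file automatically. The feed step below is only a placeholder for wherever your network code goes:

Cifar10 = Cifar10DataReader(r'E:\Tensorlow\Project\深度学习练习\cifar-10-batches-py', one_hot=True)
batch_size = 100
for step in range(50000 // batch_size):  # 5 files x 10000 images = one epoch
    datas, labels = Cifar10.next_train_data(batch_size=batch_size)
    # feed datas and labels to your network here
    if step % 100 == 0:
        print('step %d: %d images, %d labels' % (step, len(datas), len(labels)))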