本教程以读取Fashion-MNIST为例
fashion-mnist
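共需下载四个文件:train-labels-idx1-ubyte.gz、train-images-idx3-ubyte.gz、t10k-labels-idx1-ubyte.gz、t10k-images-idx3-ubyte.gz
本人将所有文件保存到以下目录中:'/home/brian/Documents/tensorflow-gpu/tensorflow-learning/data/fashion/'
(注意:请将程序中的 path 替换为你保存这些文件的文件夹路径,files 为数据集文件名列表)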
# 引入必要的库函数
import gzip
import os

import numpy as np
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.util.tf_export import tf_export
# Read the local Fashion-MNIST .gz archives and convert them to NumPy arrays.
def load_localData(
        path='/home/brian/Documents/tensorflow-gpu/tensorflow-learning/data/fashion/',
        files=('train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
               't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')):
    """Load a locally stored, gzip-compressed Fashion-MNIST dataset.

    Args:
        path: Directory that contains the four archive files.
        files: Exactly four archive file names, in the order
            (train labels, train images, test labels, test images).

    Returns:
        A flat tuple ``(x_train, y_train, x_test, y_test)`` of uint8 arrays.
        The image arrays are reshaped to ``(len(labels), 28, 28)``; the label
        arrays are one-dimensional.

    Raises:
        ValueError: If ``files`` does not hold exactly four names, or if an
            image archive cannot be reshaped to match its label count.
        FileNotFoundError: If one of the archives is missing from ``path``.
    """
    # The files are already on disk, so the paths are built directly instead
    # of going through a download helper.
    train_labels, train_images, test_labels, test_images = (
        os.path.join(path, fname) for fname in files)

    y_train = _read_idx_labels(train_labels)
    x_train = _read_idx_images(train_images, len(y_train))
    y_test = _read_idx_labels(test_labels)
    x_test = _read_idx_images(test_images, len(y_test))
    return x_train, y_train, x_test, y_test


def _read_idx_labels(archive_path):
    """Return the labels stored in a gzipped label archive as a 1-D uint8 array."""
    with gzip.open(archive_path, 'rb') as lbpath:
        # The first 8 bytes are the file header, not label data.
        return np.frombuffer(lbpath.read(), np.uint8, offset=8)


def _read_idx_images(archive_path, count):
    """Return the images of a gzipped image archive as a (count, 28, 28) uint8 array."""
    with gzip.open(archive_path, 'rb') as imgpath:
        # The first 16 bytes are the file header, not pixel data.
        pixels = np.frombuffer(imgpath.read(), np.uint8, offset=16)
    # Reshaping with the label count fails loudly if images and labels disagree.
    return pixels.reshape(count, 28, 28)
# Load the four arrays, then report each shape on its own line in the order
# train images, train labels, test images, test labels.
x_train, y_train, x_test, y_test = load_localData()
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape, sep='\n')
(60000, 28, 28)
(60000,)
(10000, 28, 28)
(10000,)