MNIST is an entry-level computer vision dataset: it contains images of handwritten digits, together with a label for each image telling us which digit it shows.
The MNIST dataset available online is split into a training set and a test set (each part consists of an image file and a corresponding label file).
Download address: http://yann.lecun.com/exdb/mnist/
The download contains the following files:
Training set: data used for training.
Test set: data used to evaluate the model's generalization ability.
We first decompress the label and image files, then use the tensorflow 2.x dataset API for preprocessing, shuffling, and batching.
import gzip
import os
import struct

import numpy as np
import tensorflow as tf


def load_images(filename):
    """Load images.
    filename: the name of the gzip file containing image data
    return -- an array of shape (num, rows, columns), dtype uint8
    """
    g_file = gzip.GzipFile(filename)
    data = g_file.read()
    # Big-endian header: magic number (2051), image count, rows, columns
    magic, num, rows, columns = struct.unpack('>iiii', data[:16])
    dimension = rows * columns
    X = np.zeros((num, rows, columns), dtype='uint8')
    offset = 16
    for i in range(num):
        a = np.frombuffer(data, dtype=np.uint8, count=dimension, offset=offset)
        X[i] = a.reshape((rows, columns))
        offset += dimension
    return X
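To see why load_images reads a 16-byte big-endian header ('>iiii': magic number 2051, image count, rows, columns) before the raw pixels, here is a self-check on a small synthetic buffer built in the same IDX layout (the sizes and pixel values are made up for illustration):

```python
import struct

import numpy as np

# Build a fake IDX image buffer: header (magic 2051, 2 images of 3x3) + pixels
num, rows, columns = 2, 3, 3
header = struct.pack('>iiii', 2051, num, rows, columns)
pixels = bytes(range(num * rows * columns))  # 18 pixel bytes: 0, 1, ..., 17
data = header + pixels

# Parse it back exactly as load_images does
magic, n, r, c = struct.unpack('>iiii', data[:16])
X = np.frombuffer(data, dtype=np.uint8, count=n * r * c, offset=16).reshape(n, r, c)
print(magic, X.shape)  # 2051 (2, 3, 3)
```

The real files differ only in scale (60,000 or 10,000 images of 28x28) and in being gzip-compressed.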
def load_labels(filename):
    """Load labels.
    filename: the name of the gzip file containing label data
    return -- a 1-D array of labels, dtype uint8
    """
    g_file = gzip.GzipFile(filename)
    data = g_file.read()
    # Big-endian header: magic number (2049), label count
    magic, num = struct.unpack('>ii', data[:8])
    d = np.frombuffer(data, dtype=np.uint8, count=num, offset=8)
    return d
def load_data(foldername):
    """Load the MNIST dataset.
    foldername: the name of the folder containing the datasets
    return -- train_X training images, train_y training labels,
              test_X test images, test_y test labels
    """
    # filenames of datasets
    train_X_name = "train-images-idx3-ubyte.gz"
    train_y_name = "train-labels-idx1-ubyte.gz"
    test_X_name = "t10k-images-idx3-ubyte.gz"
    test_y_name = "t10k-labels-idx1-ubyte.gz"
    train_X = load_images(os.path.join(foldername, train_X_name))
    train_y = load_labels(os.path.join(foldername, train_y_name))
    test_X = load_images(os.path.join(foldername, test_X_name))
    test_y = load_labels(os.path.join(foldername, test_y_name))
    return train_X, train_y, test_X, test_y
Next we call the tensorflow 2.x dataset API to preprocess the images, shuffle them, and split them into batches.
def process_image(image, label):
    """Image preprocessing."""
    # m = image.shape[0] * image.shape[1]
    # image = tf.reshape(image, (m,))  # flatten to (784,) for a fully connected network; not needed for a 2-D CNN
    label = tf.one_hot(label, depth=10)
    return image, label


def get_dataset(X, Y, batch_size=64):
    ds = tf.data.Dataset.from_tensor_slices((X, Y))
    ds = ds.map(process_image)
    ds = ds.shuffle(buffer_size=1024)
    ds = ds.batch(batch_size)
    return ds
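For intuition, the tf.one_hot(label, depth=10) call in process_image turns each digit label into a length-10 vector with a single 1 at the label's index. A NumPy sketch of the same mapping (one_hot_np is my own helper name, not part of the code above):

```python
import numpy as np

def one_hot_np(labels, depth=10):
    """Map integer labels to one-hot rows: label k -> row k of the identity matrix."""
    return np.eye(depth, dtype=np.float32)[np.asarray(labels)]

print(one_hot_np([3, 0]))
# label 3 -> [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
# label 0 -> [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
```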
Let's test reading one batch of data:
if __name__ == "__main__":
    # load the datasets
    train_X, train_y, test_X, test_y = load_data("./data/MNIST")
    train_dataset = get_dataset(train_X, train_y, batch_size=64)
    test_dataset = get_dataset(test_X, test_y, batch_size=64)
    # print one batch of each to check shapes
    for nbatch, (x, labels) in enumerate(train_dataset):
        print("train x:", x.shape)
        print("train labels:", labels.shape)
        break
    for nbatch, (x, labels) in enumerate(test_dataset):
        print("test x:", x.shape)
        print("test labels:", labels.shape)
        break
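Note that process_image above passes pixel values through unchanged as uint8 in [0, 255]. Many training pipelines also scale them to floats in [0, 1] before feeding the network; this step is not part of the code above, but a minimal NumPy sketch of it (scale_images is a hypothetical helper name) would be:

```python
import numpy as np

def scale_images(X):
    """Scale uint8 pixels in [0, 255] to float32 values in [0.0, 1.0]."""
    return X.astype(np.float32) / 255.0

# Example on a tiny fake image: 0 -> 0.0, 255 -> 1.0, 51 -> 0.2, 102 -> 0.4
demo = np.array([[0, 255], [51, 102]], dtype=np.uint8)
print(scale_images(demo))
```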