注意:这是一个完整的项目,建议您按照完整的博客顺序阅读。
前言
一、MNIST数据集的介绍
二、使用MNIST数据集
1、读取MNIST数据集
2、MNIST数据集的常见操作
三、可视化MNIST图片
1、随机抽取9张图片数据
2、可视化图片数据
3、可视化效果
提示
MNIST手写数字识别是深度学习过程中的经典项目,很可能是很多深度学习爱好者的第一个入门项目。
该项目使用的是MNIST数据集,该数据集共有7万张图片,其中6万张用于训练神经网络,1万张用于测试神经网络。
其中每张图片是一个28*28像素点的0~9的手写数字的黑底白字的灰度图片。 黑底用0表示,白字用0~1之间的浮点数表示,越接近1,颜色越白,如下图:
由于该数据集过于有名,目前全网最常见的该数据集使用的有4种格式:npz版本、gz版本、pkl.gz版本、mnist.zip版本。
不同的版本格式可能加载方式存在不同,但是其内容基本一致。
其中pkl.gz版本的下载地址为:http://www.deeplearning.net/tutorial/gettingstarted.html
本项目使用的是gz版本,该版本可以使用TensorFlow自带的数据集读取的API,其下载地址为:http://yann.lecun.com/exdb/mnist。
下载后整个数据集会被划分为三个部分:
训练集: (numpy.array-float32(50000,784), numpy.array-int64(50000,1))
验证集: (numpy.array-float32(10000,784), numpy.array-int64(10000,1))
测试集: (numpy.array-float32(10000,784), numpy.array-int64(10000,1))
注意:验证集一般用来在训练时评估训练效果,可以看做小量的测试集。
# 导入TensorFlow自带的读取数据集的API
from tensorflow.examples.tutorials.mnist import input_data
# 加载MNIST数据集
mnist = input_data.read_data_sets('./mnist/', one_hot=False)
如果one_hot=False,则标签直接为数字;
如果one_hot=True,则标签为One_hot编码后的数组。
# 获取各个数据集的样本数量
n1 = mnist.train.num_examples
n2 = mnist.validation.num_examples
n3 = mnist.test.num_examples
# 取出32个样本
xs, ys = mnist.train.next_batch(32)
# 获取训练集的图片数据和标签
_images, _labels = mnist.train.images, mnist.train.labels
# 获取训练集的图片数据和标签
_images, _labels = mnist.train.images, mnist.train.labels
# 随机抽取9个数据样本
random_indices = random.sample(range(len(_images)), min(len(_images), 9))
images, labels = zip(*[(_images[i], _labels[i]) for i in random_indices])
# 可视化样本
plot_images(images=images, cls_true=labels,img_size=28,num_channels=1)
为了检验一些我们的图片数据,我们这里设计了一个可视化图像数据的程序,其实就是一个方法,这个方法的思路非常简单:
def plot_images(images, cls_true, img_size=28, cls_pred=None, num_channels=1):
# 检测图像是否存在
if len(images) <= 0 or len(images)>9:
print("没有图像来展示或者图像个数过多")
return
# 创造一个3行3列的画布
fig, axes = plt.subplots(3, 3)
fig.subplots_adjust(hspace=0.6, wspace=0.6)
fig.canvas.set_window_title('Random Images show') # 设置字体大小与格式
for i, ax in enumerate(axes.flat):
# 显示图片
if len(images) < i + 1:
break
ax.imshow(images[i].reshape(img_size, img_size, num_channels))
# 展示图像的语义标签和实际预测标签
if cls_pred is None:
xlabel = "True: {0}".format(cls_true[i])
else:
xlabel = "True: {0}, Pred: {1}".format(cls_true[i], cls_pred[i])
# 设置每张图的标签为其xlabel.
ax.set_xlabel(xlabel)
# 设置图片刻度
ax.set_xticks([0, img_size])
ax.set_yticks([0, img_size])
plt.show()
注意这里我们的MNIST数据为28*28的黑白图片,所以我们的参数img_size=28, num_channels=1。
最终输出:
如果本项目对您的学习有帮助,欢迎点赞支持!