from tensorflow.examples.tutorials.mnist import input_data
import scipy.misc
import os
导入相关库文件。
1、input_data主要是包含了mnist数据集中的大量数据,用于导出官方相关的测试所需数据。
2、scipy.misc主要包含一些图像的输入输出函数,例如下文中用到的toimage函数用于显示指定路径下的图片。
3、os主要包含了一些对于文件及其路径的相关函数,下文中运用到的os.path.exists()就是用于判断该路径的文件是否存在。
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
导入相关的实验数据,从TensorFlow官方提供的MNIST数据集中下载到本地,这里容易遇到坑,例如导不进四个压缩文件,当出现这种情况时可以去官网手动下载四个压缩文件并且直接剪切至项目相关文件夹处。此时可能会出现warning,这个时候就不要管了,warning是因为版本的问题,但是相关数据是可以进行操作的,如果有大佬有更好的解决方法可以给小编留言或者是评论,万分感谢。
save_dir = 'MNIST_data/raw/'
if os.path.exists(save_dir) is False:
os.mkdir(save_dir)
设置保存图片的路径,os.path.exists(path)这个函数是用来判断该文件是否存在的,如果不存在就进行os.mkdir()来创建一个文件夹用于存储相关的图片集合
for i in range(20):
image_array = mnist.train.images[i, :]
image_array = image_array.reshape(28,28)
filename = save_dir + 'mnist_train_%d.jpg' % i
scipy.misc.toimage(image_array,cmin=0.0,cmax=1.0).save(filename)
这段循环是用来存储20张图片的,原来的数据集中图片信息是用一维数组来表示的,现在利用reshape()函数来将一维图像数组转化成二维数组,然后将其存储到相关的文件夹下面。元素的取值范围是[0,1]。