6月15日,代码链接更新...
在这一篇文章里,我们将继续上一篇文章的工作,并且在上一篇文章的前提下加入数据集的制作,最终我们将完成这个全连接神经网络的小栗子.
先说说我们上一篇文章我们的自制数据集的一切缺点,第一,数据集过于分散,在一个文件夹里读取难免导致内存利用率低,而我们将会使用TensorFlow的tfrecords()函数来讲图片和标签制作成这种二进制文件,这样我们的内存利用率会增加不少.
将数据保存为tfrecords文件可以视为这样一个流程:
提取features -> 保存为Example结构对象 -> TFRecordWriter写入文件
而如果我们要存储训练数据的时候,我们会使用tf.train.Example()来去存储,并且训练数据的特征用键值对来表示.
现在让我们看看代码:
1:首先,我们先加入文件的路径.这些文件是从mnist数据集中随机找了一些,当然,我们也可以换成别的.
定义变量,加入引用的路径:
2:制作数据集.
首先我们新建一个writer.
然后使用for循环来去遍历我们文件中的每一张图和每一张图的标签
最后我们把每张图片及其标签封装到example中
最终将其序列化后即可完成.
这样我们的数据集就已经制作完成了.
3:读取tfrecords()文件
具体代码:
先从read_tfRecord函数说起:
在这个函数中,我们主要使用的是:
filename_queue = tf.train.string_input_producer([tfRecord_path])
在使用这个函数后,这个函数会产生一个先进先出的队列,文本阅读器会用它来读取数据.
而这个队列,我们在get_tfrecord中使用到:
具体的参数,在此说明下:
tf.train.string_input_producer( string_tensor, #存储图像和标签信息的 TFRecord 文件名列表
num_epochs=None, #循环读取的轮数(可选)
shuffle=True,#布尔值(可选),如果为 True,则在每轮随机打乱读取顺序
seed=None,#随机读取时设置的种子(可选)
capacity=32, #设置队列容量
shared_name=None, #如果设置,该队列将在多个会话中以给定名 称共享。所
name=None,#操作的名称(可选)
cancel_op=None)#取消队列(None)
接着说reader_tfrecord中:
reader = tf.TFRecordReader() #新建一个 reader
这个操作是把读出的样本在serialized_example中进行解析,标签和图片的键名应该和制作 tfrecords 的键名相同.该函数可以将 tf.train.Example 协议内存块(protocol buffer)解析为张量。
img = tf.decode_raw(features['img_raw'], tf.uint8) #将 img_raw 字符串转换为 8 位无符号整型 img.set_shape([784]) #将形状变为一行 784 列
img = tf.cast(img, tf.float32) * (1. / 255) #变成 0 到 1 之间的浮点数
label = tf.cast(features['label'], tf.float32)#把标签列表变为浮点数
return image,label #返回图片和标签(跳回到 get_tfrecord)
回到get_tfrecord中:
tf.train.shuffle_batch(),随机读取一个batch的数据
这个函数值得说说,完整的格式如下:
tf.train.shuffle_batch(
tensors, #待乱序处理的列表中的样本(图像和标签)
batch_size, #从队列中提取的新批量大小
capacity, #队列中元素的最大数量
min_after_dequeue, #出队后队列中的最小数量元素,用于确保元素 的混合级别
num_threads=1, #排列 tensors 的线程数
seed=None, #用于队列内的随机洗牌
enqueue_many=False, #tensor 中的每个张量是否是一个例子
shapes=None, #每个示例的形状
allow_smaller_final_batch=False, #如果为 True,则在 队列中剩余数量不足时允许最终批次更小。
shared_name=None, #如果设置,该队列将在多个会话中以给定名称 共享。
name=None #操作的名称)
最后返回的图片和标签为随机抽取的 batch_size 组
而在下一篇文章中,我们将在反向传播文件中修改图片标签的获取接口,并且利用多线程来去提高图片和标签的批处理获取效率.
代码地址:链接:https://pan.baidu.com/s/1AS6X2hpwtNdAjW7CMz7Rxw 密码:8vih