Tensorflow中关于FixedLengthRecordReader()的理解

 在官方案例 CIFAR-10中有一个关于数据输入的程序 cifar10_input.py,

函数read_cifar10()的作用就是读取file_name文件队列中的文件,刚开始我的疑惑是整个函数好像只能读取一张图片(为什么可以读取多张图片形成batch,请参考),且读取位置是否总会从文件开头读取。

def read_cifar10(filename_queue):
'''
@param  filename_queue  要读取的文件名队列
@return 某个对象,具有以下字段:
        height: 结果中的行数 (32)
        width:  结果中的列数 (32)
        depth:  结果中颜色通道数(3)
        key:    一个描述当前抽样数据的文件名和记录数的标量字符串
        label:  一个 int32类型的标签,取值范围 0..9.
        uint8image: 一个[height, width, depth]维度的图像数据
'''
   
  # 建立一个空类,方便数据的结构化存储
  class CIFAR10Record(object):
    pass
  result = CIFAR10Record()

  # Dimensions of the images in the CIFAR-10 dataset.
  # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
  # 设置输入样本的格式.
  label_bytes = 1  # 2 for CIFAR-100
  result.height = 32   #图片高度
  result.width = 32    #图片宽度
  result.depth = 3     #通道数
  
  image_bytes = result.height * result.width * result.depth
  # Every record consists of a label followed by the image, with a
  # fixed number of bytes for each.
  record_bytes = label_bytes + image_bytes

  #读取固定长度字节数信息(针对bin文件使用FixedLengthRecordReader读取比较合适)
  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
  result.key, value = reader.read(filename_queue)

  # 字符串转为张量数据decode_raw
  record_bytes = tf.decode_raw(value, tf.uint8)

  # tf.stride_slice(data, begin, end):从张量中提取数据段,并用cast进行数据类型转换
  result.label = tf.cast(
      tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

  # The remaining bytes after the label represent the image, which we reshape
  # from [depth * height * width] to [depth, height, width].
  depth_major = tf.reshape(
      tf.strided_slice(record_bytes, [label_bytes],
                       [label_bytes + image_bytes]),
      [result.depth, result.height, result.width])
  # Convert from [depth, height, width] to [height, width, depth].
  result.uint8image = tf.transpose(depth_major, [1, 2, 0])

  return result

下面测试了tf.FixedLengthRecordReader是读取固定长度字节数信息(针对bin文件使用FixedLengthRecordReader读取比较合适),结果表明下次调用时会接着上次读取的位置继续读取文件,而不会从头开始读取。

import tensorflow as tf

filenames = ['D:/Tensorflow/test/txt1.txt']
filename_queue = tf.train.string_input_producer(filenames)

reader = tf.FixedLengthRecordReader(record_bytes=4)

key, value = reader.read(filename_queue)
b = value
sess = tf.InteractiveSession()
tf.train.start_queue_runners(sess=sess)

print(sess.run(b))
print('\n')
print(sess.run(b))

文本信息:

Tensorflow中关于FixedLengthRecordReader()的理解_第1张图片

输出结果:

Tensorflow中关于FixedLengthRecordReader()的理解_第2张图片


关于tensorflow数据读取操作详解


你可能感兴趣的:(Tensoflow知识点)