Tensorflow直接读取二进制文件

Tensorflow直接读取二进制文件

Tensorflow可以直接读取记录为固定长度的bin文件,比如cifar-10,流程基本与读取csv文件一直,只有一些细微的差别。

import tensorflow as tf
import numpy as np

# 预定义图像数据信息
labelBytes = 1
witdthBytes = 32
heightBytes = 32
depthBytes = 3
imageBytes = witdthBytes*heightBytes*depthBytes
recordBytes = imageBytes+labelBytes

filename_queue = tf.train.string_input_producer(["./data/train.bin"])
reader = tf.FixedLengthRecordReader(record_bytes=recordBytes) # 按固定长度读取二进制文件
key,value = reader.read(filename_queue)

bytes = tf.decode_raw(value,out_type=tf.uint8) # 解码为uint8,0-255 8位3通道图像
label = tf.cast(tf.strided_slice(bytes,[0],[labelBytes]),tf.int32) # 分割label并转化为int32

originalImg  = tf.reshape(tf.strided_slice(bytes,[labelBytes],[labelBytes+imageBytes]),[depthBytes,heightBytes,witdthBytes])
# 分割图像,此时按照数据组织形式深度在前
img = tf.transpose(originalImg,[1,2,0]) # 调整轴的顺序,深度在后

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    for i in range(100):
        imgArr = sess.run(img)
        print (imgArr.shape)

    coord.request_stop()
    coord.join(threads)

需要注意的地方:

  1. 阅读区使用FixedLengthRecordReader读取固定长度记录,decode_raw解析二进制位无符号8位数。
  2. strided_slice分割函数,起点、终点和步长都是要引入列表形式,实际上在多维情况下,指示的是坐标。
  3. 最终数据要按照数据的组织形式进行合理变换。

你可能感兴趣的:(机器学习)