tensorflow读写二进制文件

tensorflow建议用Dataset取代queue读取数据,Dataset读取数据有占用内存小、方便对数据预处理等优点,但是读写文件需要注意格式问题,下面实现一个读写二进制文件的例子:
首先创建numpy数组存成二进制文件

M = 200; N = 5; np.random.seed(231)
a = np.random.rand(M, N)
file_a = 'test_a.txt'
with open(file_a, 'wb') as f:
    f.write(a)

然后定义解析二进制文件函数

def decode_float64(x):
    return tf.decode_raw(x, tf.float64)

最后读取文件

dataset = tf.data.FixedLengthRecordDataset(file_a, record_bytes=8*N)
dataset = dataset.map(decode_float64)
element = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
    for i in range(5):
        print(sess.run(element))

输出结果为


输出结果

还可以通过batch/map/shuffle/zip等函数对Dataset类进行修改

你可能感兴趣的:(tensorflow读写二进制文件)