tfrecord 基本用法

    tfrecord格式是tensorflow官方推荐的数据格式,把数据、标签进行统一的存储

   tfrecord文件包含了tf.train.Example 协议缓冲区(protocol buffer,协议缓冲区包含了特征 Features), 能让tensorflow更好的利用内存。

    把某个文件夹的图片和标签存入同一个tfrecord文件,代码如下:

def write(input_file, output_file):
    writer = tf.python_io.TFRecordWriter(output_file) #定义writer,传入目标文件路径
    path = input_file
    file_names = [f for f in os.listdir(path) if f.endswith('.jpg')] #获取待存文件路径
    for file_name in file_names:
        img = cv2.imread(path + file_name)
        raw_img = img.tobytes() #需要把图片文件转化成bytes形式(二进制比特流)

        # 把数据合并成feature,注意这里的"value="后面一定要是一个"[]"形式的列表,否则读取的时候会出现can't parse的情况
        features = tf.train.Features(feature={'img_name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[file_name])),
                   'raw_img': tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw_img]))})
        #把features存入example
        example = tf.train.Example(features=features)
        #example序列化,并写入文件
        writer.write(example.SerializeToString())
    writer.close()


input_file = 'samples/'
output_file = 'samples.tfrecords'
write(input_file, output_file)
print 'Write tfrecords: %s done' %output_file

    基本步骤:

  1. 读取待存文件内容,转化为bytes形式
  2. 数据合并成tf.train.Features(类似dict形式)
  3. 把features存入一个tf.train.Example
  4. 把example序列化,并写入文件

    写入文件的实际上是若干个example。

    其中,tf.train.Features的bytes_list支持的类型有三种:tf.train.ByteList、tf.train.FloatList、tf.train.Int64List,形式如下:

tf.train.Features(
                feature={
                'label': tf.train.Feature(int64_list = tf.train.Int64List(value=[i])),     
                #'label': tf.train.Feature(float_list = tf.train.FloatList(value=[i])),
                'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))
                }))

    读取文件,代码如下:

def read_and_decode(file_name):
    filename_queue = tf.train.string_input_producer([file_name])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={"img_name": tf.FixedLenFeature([], tf.string),
                                                 "raw_img": tf.FixedLenFeature([], tf.string)})
    img_name = features["img_name"]
    image = tf.decode_raw(features['raw_img'], tf.uint8)
    image = tf.reshape(image, [256, 256, 3])
    return img_name, image


path = 'samples.tfrecords'
with tf.Session() as sess:
    img_name, img = read_and_decode(path)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for i in range(n):
            name, image = sess.run([img_name, img])
            print name
            plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    except tf.errors.OutOfRangeError:
        print 'Done training -- epoch limit reached'
    finally:
        coord.request_stop()

    coord.join(threads)

    上述读取过程结合了文件队列,基本过程如下:

  1. 定义文件名队列ft.train.string_input_producer
  2. 定义tf.TFRecordReader
  3. 读取序列化的example
  4. 调用tf.tf.parse_single_example解析example,得到features
  5. 从features获取具体的数据,如果是图像,进行解码和reshape(还可以进行相关的预处理)

    上述1~5步是读取一个example的“graph”,在实际使用时,先定义好graph,然后start_queue_runners(注意先后顺序,否则进程将阻塞),再根据需要,循环读取数据。

你可能感兴趣的:(机器学习)