tfrecord读取过程简介

通常做法是使用Tensorflow的Dataset来读取我们的tfRecord,但是老的版本也有通过TFRecordReader进行解析,这里我们先介绍使用Dataset方式读取

  • 加载TFRecord文件
  • 通过parse_fn方法对每条样本机型解析
  • 重复N epochs
  • batch
def parse_fn(example_proto):
    features = {
     "state": tf.FixedLenFeature((), tf.string),
                "action": tf.FixedLenFeature((), tf.int64),
                "reward": tf.FixedLenFeature((), tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, features)
    return tf.decode_raw(parsed_features['state'], tf.float32), parsed_features['action'], parsed_features['reward']


with tf.Session() as sess:
    dataset = tf.data.TFRecordDataset(output_file)  # 加载TFRecord文件
    dataset = dataset.map(parse_fn)  # 解析data到Tensor
    dataset = dataset.repeat(1)  # 重复N epochs
    dataset = dataset.batch(3)  # batch size

    iterator = dataset.make_one_shot_iterator()
    next_data = iterator.get_next()

    while True:
        try:
            state, action, reward = sess.run(next_data)
            print(state)
            print(action)
            print(reward)
        except tf.errors.OutOfRangeError:
            break

遍历结果:
tfrecord读取过程简介_第1张图片

解析tfrecord的2种方式

for example in tf.io.tf_record_iterator(output_file):
    print("first method")
    print(tf.train.Example.FromString(example))
    # 或者用下面的方法
    print("second method")
    from google.protobuf.json_format import MessageToJson
    jsonMessage = MessageToJson(tf.train.Example.FromString(example))
    print(jsonMessage)

解析结果:

first method
features {
     
  feature {
     
    key: "action"
    value {
     
      int64_list {
     
        value: 1
      }
    }
  }
  feature {
     
    key: "reward"
    value {
     
      int64_list {
     
        value: 90
      }
    }
  }
  feature {
     
    key: "state"
    value {
     
      bytes_list {
     
        value: "\037\205\277B\341zD@\217\302\247A\270\036\303B\205\353\237Aff\230A33\031B\315\314\300A\nW\202B\244p/B"
      }
    }
  }
}

second method
{
     
  "features": {
     
    "feature": {
     
      "action": {
     
        "int64List": {
     
          "value": [
            "1"
          ]
        }
      },
      "state": {
     
        "bytesList": {
     
          "value": [
            "H4W/QuF6RECPwqdBuB7DQoXrn0FmZphBMzMZQs3MwEEKV4JCpHAvQg=="
          ]
        }
      },
      "reward": {
     
        "int64List": {
     
          "value": [
            "90"
          ]
        }
      }
    }
  }
}

完整代码:

"""
本程序演示了如何保存numpy array为TFRecords文件,并将其读取出来。
"""
import random

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

def save_tfrecords(state_data, action_data, reward_data, dest_file):
    """
    保存numpy array到TFRecord文件中。
    这里输入了三个不同的numpy array来做演示,它们含有不同类型的元素。
    Args:
        state_data: 要保存到TFRecord文件的第1个numpy array,每一个 state_data[i] 是一个 numpy.ndarray(数组里的每个元素又是一个浮点
                    数),因此不能用 Int64List 或 FloatList 来存储,只能用 BytesList。
        action_data: 要保存到TFRecord文件的第2个numpy array,每一个 action_data[i] 是一个整数,使用 Int64List 来存储。
        reward_data: 要保存到TFRecord文件的第3个numpy array,每一个 reward_data[i] 是一个整数,使用 Int64List 来存储。
        dest_file: 输出文件的路径。
    Returns:
        不返回任何值
    """
    with tf.io.TFRecordWriter(dest_file) as writer:
        for i in range(len(state_data)):
            features = tf.train.Features(
                feature={
     
                    "state": tf.train.Feature(
                        bytes_list=tf.train.BytesList(value=[state_data[i].astype(np.float32).tobytes()])),

                    # "state": tf.train.Feature(
                    #     float_list=tf.train.FloatList(value=state_data[i].astype(np.float))),
                    # "action": tf.train.Feature(
                    #     int64_list=tf.train.Int64List(value=[action_data[i]])),
                    # "reward": tf.train.Feature(
                    #     int64_list=tf.train.Int64List(value=[reward_data[i]]))

                    "action": tf.train.Feature(
                        int64_list=tf.train.Int64List(value=action_data[i].astype(np.int))),
                    "reward": tf.train.Feature(
                        int64_list=tf.train.Int64List(value=reward_data[i].astype(np.int)))
                }
            )
            tf_example = tf.train.Example(features=features)
            serialized = tf_example.SerializeToString()
            writer.write(serialized)


def parse_fn(example_proto):
    features = {
     "state": tf.FixedLenFeature((), tf.string),
                "action": tf.FixedLenFeature((), tf.int64),
                "reward": tf.FixedLenFeature((), tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, features)
    return tf.decode_raw(parsed_features['state'], tf.float32), parsed_features['action'], parsed_features['reward']


if __name__ == '__main__':
    buffer_s, buffer_a, buffer_r = [], [], []

    # 随机生成一些数据
    for i in range(30):
        state = [round(random.random() * 100, 2) for _ in range(0, 10)]  # 一个数组,里面有10个数,每个都是一个浮点数
        action = random.randrange(0, 2)  # 一个数,值为 0 或 1
        reward = random.randrange(0, 100)  # 一个数,值域 [0, 100)
        # 把生成的数分别添加到3个list中
        buffer_s.append(state)
        buffer_a.append(action)
        buffer_r.append(reward)

        # 查看生成的数据
    print(buffer_s)
    print(buffer_a)
    print(buffer_r)

    # 在水平方向把各个list堆叠起来,堆叠的结果:得到3个矩阵
    s_stacked = np.vstack(buffer_s)
    a_stacked = np.vstack(buffer_a)
    r_stacked = np.vstack(buffer_r)

    print(s_stacked.shape)  # (3, 10)
    print(a_stacked.shape)  # (3, 1)
    print(r_stacked.shape)  # (3, 1)


    print(s_stacked)
    print(a_stacked)
    print(r_stacked)


    print("data generate sucess!")

    # 写入TFRecord文件
    output_file = './data.tfrecord'  # 输出文件的路径
    save_tfrecords(s_stacked, a_stacked, r_stacked, output_file)

    # 读取TFRecord文件并打印出其内容
    for example in tf.io.tf_record_iterator(output_file):
        print("first method")
        print(tf.train.Example.FromString(example))
        # 或者用下面的方法
        print("second method")
        from google.protobuf.json_format import MessageToJson
        jsonMessage = MessageToJson(tf.train.Example.FromString(example))
        print(jsonMessage)

    # 读取TFRecord文件并还原成numpy array,再打印出来
    with tf.Session() as sess:
        dataset = tf.data.TFRecordDataset(output_file)  # 加载TFRecord文件
        dataset = dataset.map(parse_fn)  # 解析data到Tensor
        dataset = dataset.repeat(1)  # 重复N epochs
        dataset = dataset.batch(3)  # batch size

        iterator = dataset.make_one_shot_iterator()
        next_data = iterator.get_next()

        while True:
            try:
                print("get next")
                state, action, reward = sess.run(next_data)
                print(state)
                print(action)
                print(reward)
            except tf.errors.OutOfRangeError:
                break

TFRecordReader方式

  • tf.train.string.input_producer 读取序列化后的的TFRecord记录,生成一个QueueRunner,它包含一个FIFOQueue队列
  • 通过tf.TFRecordReader() 依据定义的模式,进行反序列化parse,可以附带一些转换操作
  • batch,通过tf.train.shuffle_batch生成了RandomShuffleQueue
  • 通过 tf.train.Coordinator() tf.train.start_queue_runners 载入数据训练

附录:
使用tensorflow中的Dataset来读取制作好的tfrecords文件

你可能感兴趣的:(tf,python,tensorflow,深度学习)