Dataset读取tfrecord

    def tfrecord_pipeline(cls, tfrecord_file, batch_size, prebatch,
                          epochs, shuffle=True, shuffle_buffer_size=128,
                          num_parallel_calls=4, prefetch_buffer_size=512):
        """Build an initializable iterator over the TFRecords listed in a file.

        Args:
            tfrecord_file: Path to a text file holding one TFRecord path per
                line. Surrounding whitespace is stripped and blank lines
                are ignored.
            batch_size: Number of parsed records combined into one batch.
            prebatch: Multiplier for the fixed feature lengths. Each record
                must carry `prebatch` labels, `prebatch * sparse_features`
                categorical ids and `prebatch * dense_features` numerical
                values (presumably `prebatch` examples packed per record;
                confirm against the writer of the TFRecords).
            epochs: Repeat count passed to `Dataset.repeat`.
            shuffle: When true, shuffle the serialized records before they
                are parsed.
            shuffle_buffer_size: Buffer size given to `Dataset.shuffle`.
            num_parallel_calls: Parallelism given to `Dataset.map`.
            prefetch_buffer_size: Buffer size given to `Dataset.prefetch`;
                it is applied after batching, so it counts batches.

        Returns:
            The iterator returned by `make_initializable_iterator()`. Its
            elements are dicts keyed by CATEGORICAL_FEATURE_NAME, 'label'
            and 'numerical'. The caller has to run its initializer.

        Raises:
            ValueError: If `tfrecord_file` is not an existing regular file.
        """
        # tfrecord file should be a text file with absolute path of tfrecords
        if not os.path.isfile(tfrecord_file):
            raise ValueError('{} should be a text file'.format(tfrecord_file))
        with open(tfrecord_file) as f:
            stripped_lines = (line.strip() for line in f)
            # Skip blank lines (e.g. a trailing newline): an empty string is
            # not a usable TFRecord filename.
            record_files = [path for path in stripped_lines if path]

        def parser(record):
            """Decode one serialized example into a dict of dense tensors."""
            # The flag values are read here, when the parser is traced by
            # `map`, not when `tfrecord_pipeline` is defined.
            flags = tf.app.flags.FLAGS
            feature_map = {
                CATEGORICAL_FEATURE_NAME: tf.FixedLenFeature(
                    [prebatch * flags.sparse_features], tf.int64),
                'label': tf.FixedLenFeature([prebatch], tf.int64),
                'numerical': tf.FixedLenFeature(
                    [prebatch * flags.dense_features], tf.float32),
            }
            return tf.parse_single_example(record, features=feature_map)

        dataset = tf.data.TFRecordDataset(filenames=record_files)
        if shuffle:
            dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
        dataset = dataset.repeat(epochs)
        dataset = dataset.map(parser, num_parallel_calls=num_parallel_calls)
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(buffer_size=prefetch_buffer_size)
        return dataset.make_initializable_iterator()

 

你可能感兴趣的:(Dataset读取tfrecord)