import tensorflow as tf
import tensorflow.keras.backend as K
import numpy as np
def _parse_function(example_proto):
    """Decode one serialized ``tf.train.Example`` into an image/label dict.

    Intended as the ``map_func`` of a ``tf.data`` pipeline, i.e. it is applied
    to every element of the dataset.

    Args:
        example_proto: string tensor holding one serialized example.

    Returns:
        The dict produced by ``tf.parse_single_example`` with both entries
        replaced:
          * ``'image'``: raw uint8 bytes reshaped to [128, 128, 3] and cast to
            float32 (no rescaling is applied).
          * ``'Young'``: int64 tensor of shape [1] mapped through
            ``(y + 1) // 2``, which turns -1 into 0 and 1 into 1.
            NOTE(review): assumes the stored labels are -1/1 (CelebA style);
            confirm against the TFRecord writer.
    """
    schema = {
        'image': tf.FixedLenFeature([], tf.string),
        'Young': tf.FixedLenFeature([1], tf.int64),
    }
    example = tf.parse_single_example(example_proto, schema)

    raw_pixels = tf.decode_raw(example['image'], tf.uint8)
    label = (example["Young"] + 1) // 2

    example['image'] = tf.cast(tf.reshape(raw_pixels, [128, 128, 3]), tf.float32)
    example['Young'] = tf.cast(tf.reshape(label, [1]), tf.int64)
    return example
if __name__ == "__main__":
    # Demo: stream shuffled batches from a TFRecord file for one epoch and
    # print the label shape of each batch until the data runs out.
    filename = '/home/dddzz/桌面/celabaA_40attributes/data/train_cropped.tfrecords'
    batch_size = 64

    # Pipeline: serialized records -> parsed dicts -> shuffled -> fixed-size
    # batches (the incomplete final batch is dropped).
    serialized_example = tf.data.TFRecordDataset(filenames=[filename])
    parsed_example = serialized_example.map(_parse_function)
    shuffle_parsed_example = parsed_example.shuffle(buffer_size=batch_size * 5)
    shuffle_parsed_example = shuffle_parsed_example.batch(batch_size, drop_remainder=True)

    iterator = shuffle_parsed_example.make_initializable_iterator()
    next_batch = iterator.get_next()

    # One session, closed automatically when the block exits.
    with tf.Session() as sess:
        sess.run(iterator.initializer)
        i = 1
        while True:
            try:
                # Fetch the whole batch dict in a single run. Every evaluation
                # of `next_batch` advances the iterator, so a separate
                # sess.run for the labels would return labels from a
                # *different* batch and skip every other batch.
                batch = sess.run(next_batch)
            except tf.errors.OutOfRangeError:
                print("End of dataset")
                break
            # Not wrapped in a list, so this is the (batch_size, 1) label array.
            young = batch['Young']
            print('======example %s======' % i)
            print(np.shape(young))
            i += 1
参考资料：
1. https://zhuanlan.zhihu.com/p/33223782
2. https://blog.csdn.net/Fenplan/article/details/90667045
整理后的代码（将数据管道封装为函数）如下：
import tensorflow as tf
import tensorflow.keras.backend as K
import numpy as np
def _parse_function(example_proto):
    """Decode one serialized ``tf.train.Example`` into an image/label dict.

    Intended as the ``map_func`` of a ``tf.data`` pipeline, i.e. it is applied
    to every element of the dataset.

    Args:
        example_proto: string tensor holding one serialized example.

    Returns:
        The dict produced by ``tf.parse_single_example`` with both entries
        replaced:
          * ``'image'``: raw uint8 bytes reshaped to [128, 128, 3] and cast to
            float32 (no rescaling is applied).
          * ``'Young'``: int64 tensor of shape [1] mapped through
            ``(y + 1) // 2``, which turns -1 into 0 and 1 into 1.
            NOTE(review): assumes the stored labels are -1/1 (CelebA style);
            confirm against the TFRecord writer.
    """
    schema = {
        'image': tf.FixedLenFeature([], tf.string),
        'Young': tf.FixedLenFeature([1], tf.int64),
    }
    example = tf.parse_single_example(example_proto, schema)

    raw_pixels = tf.decode_raw(example['image'], tf.uint8)
    label = (example["Young"] + 1) // 2

    example['image'] = tf.cast(tf.reshape(raw_pixels, [128, 128, 3]), tf.float32)
    example['Young'] = tf.cast(tf.reshape(label, [1]), tf.int64)
    return example
def get_data_batch(sess, filename, batch_size=64, shuffle_buffer_size=None):
    """Build a shuffled, batched TFRecord input pipeline and initialize it.

    Args:
        sess: session used to run the iterator's initializer.
        filename: path of the TFRecord file to read.
        batch_size: number of examples per batch. A final incomplete batch is
            dropped (``drop_remainder=True``).
        shuffle_buffer_size: number of elements held in the shuffle buffer.
            ``None`` (the default) means ``batch_size * 5``.

    Returns:
        The iterator's ``get_next()`` result: the dict built by
        ``_parse_function``, batched along a new leading axis. Each
        ``sess.run`` on it yields the next batch. Once the single pass over
        the file is exhausted, ``tf.errors.OutOfRangeError`` is raised (the
        dataset is not repeated).
    """
    if shuffle_buffer_size is None:
        shuffle_buffer_size = batch_size * 5

    # Serialized records -> parsed dicts -> shuffled -> fixed-size batches.
    dataset = (
        tf.data.TFRecordDataset(filenames=[filename])
        .map(_parse_function)
        .shuffle(buffer_size=shuffle_buffer_size)
        .batch(batch_size, drop_remainder=True)
    )

    # An initializable iterator must have its initializer run before the first
    # get_next() evaluation. Doing it here gives the caller a ready-to-use
    # batch tensor (this is a tf.data iterator, not a queue runner).
    iterator = dataset.make_initializable_iterator()
    next_batch = iterator.get_next()
    sess.run(iterator.initializer)
    return next_batch
if __name__ == "__main__":
    # Demo: build the pipeline, show its symbolic output, then pull one batch.
    filename = '/home/dddzz/桌面/celabaA_40attributes/data/train_cropped.tfrecords'
    batch_size = 64
    # The session is closed automatically when the block exits.
    with tf.Session() as sess:
        train_batch = get_data_batch(sess, filename, batch_size)
        # `train_batch` is a dict of symbolic tensors. Printing it shows only
        # the static shapes/dtypes, not any data.
        print(train_batch)
        # Evaluate it once to actually read a batch through the pipeline.
        batch = sess.run(train_batch)
        for key in sorted(batch):
            print(key, batch[key].shape, batch[key].dtype)