TensorFlow —— 使用 tf.data.Dataset 读取数据的代码
- 下面是使用 tf.data.Dataset 读取通用数据(图片文件路径 + 标签)的基础代码。
import tensorflow as tf
class TfDataDataset(object):
    """Read paired image/label data with tf.data.Dataset.

    ``data`` is a 4-tuple ``(img1_paths, img2_paths, label1, label2)`` whose
    elements are sliced together by ``from_tensor_slices`` (so they must have
    the same length). Each image path is read, decoded as a JPEG, resized to
    ``(im_size, im_size, 3)`` and scaled to the range [-1, 1].
    """

    def __init__(self, im_size):
        """Store the side length (pixels) that every image is resized to."""
        self._im_size = im_size

    def train(self, data, batch_size, num_epochs=None):
        """Build the input pipeline and fetch one batch in a TF1-style session.

        NOTE: ``tf.compat.v1.data.make_initializable_iterator`` requires graph
        mode; under TF2 call ``tf.compat.v1.disable_eager_execution()`` first.

        Args:
            data: ``(img1_paths, img2_paths, label1, label2)``.
            batch_size: number of examples per batch.
            num_epochs: passes over the data; ``None`` repeats indefinitely.

        Returns:
            The first batch ``(img1, img2, label1, label2)`` as evaluated arrays.
        """
        iterator, batch = self.get_batch(data, batch_size, num_epochs)
        with tf.compat.v1.Session() as sess:
            sess.run(iterator.initializer)
            img1, img2, label1, label2 = sess.run(batch)
        return img1, img2, label1, label2

    def get_batch(self, data, batch_size, num_epochs=None, shuffle_buffer=10000):
        """Return ``(iterator, next_batch)`` for a shuffled, batched pipeline.

        Args:
            data: ``(img1_paths, img2_paths, label1, label2)``.
            batch_size: number of examples per batch.
            num_epochs: passes over the data; ``None`` repeats indefinitely.
            shuffle_buffer: size of the ``Dataset.shuffle`` buffer.

        Returns:
            The initializable iterator (its ``initializer`` must be run before
            use) and the tensors produced by ``iterator.get_next()``.
        """
        def read_image(filename):
            # Read one JPEG file and return a float image in [-1, 1].
            image_value = tf.io.read_file(filename)
            # channels=3 forces RGB output; with the default (0) grayscale
            # JPEGs decode to 1 channel and the set_shape below fails.
            img = tf.io.decode_jpeg(image_value, channels=3)
            image_resize = tf.compat.v1.image.resize_images(
                img, [self._im_size, self._im_size])
            image_resize.set_shape([self._im_size, self._im_size, 3])
            # Map pixel values from [0, 255] to [-1, 1].
            return tf.cast(image_resize, dtype=tf.float32) * (1. / 255) * 2 - 1

        def preprocess(img1, img2, label1, label2):
            # Decode both images of an example; labels pass through unchanged.
            return read_image(img1), read_image(img2), label1, label2

        img1, img2, label1, label2 = data
        dataset = tf.data.Dataset.from_tensor_slices((img1, img2, label1, label2))
        dataset = dataset.repeat(num_epochs)
        # Dataset transformations return a NEW dataset, so the result must be
        # reassigned. Shuffling before `map` shuffles cheap file paths rather
        # than decoded images.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)
        dataset = dataset.map(preprocess)
        dataset = dataset.batch(batch_size=batch_size)
        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        return iterator, iterator.get_next()
- 下面是使用 tf.data.Dataset 读取 TFRecord 数据的基础代码,以及当数据分散在多个 TFRecord 文件中时实现完全打乱(shuffle)读取的代码。
class TfDataDatasetTfrecord(object):
    """Read ``(img1, img2, label1, label2)`` examples from TFRecord files.

    Each serialized example is parsed with the keys ``"img1"``/``"img2"``
    (byte strings decoded as raw uint8 pixels and reshaped to
    ``(im_size, im_size, 3)``) and ``"label1"``/``"label2"`` (int64 scalars).
    """

    def __init__(self, im_size):
        """Store the side length (pixels) of the stored raw images."""
        self._im_size = im_size

    def train(self, data, batch_size, num_epochs=None):
        """Build the TFRecord pipeline and fetch one batch in a TF1-style session.

        NOTE: ``tf.compat.v1.data.make_initializable_iterator`` requires graph
        mode; under TF2 call ``tf.compat.v1.disable_eager_execution()`` first.

        Args:
            data: TFRecord filename(s) passed to ``get_batch``.
            batch_size: number of examples per batch.
            num_epochs: passes over the data; ``None`` repeats indefinitely.

        Returns:
            The first batch ``(img1, img2, label1, label2)`` as evaluated arrays.
        """
        iterator, batch = self.get_batch(data, batch_size, num_epochs)
        with tf.compat.v1.Session() as sess:
            sess.run(iterator.initializer)
            img1, img2, label1, label2 = sess.run(batch)
        return img1, img2, label1, label2

    def _decode_image(self, raw):
        """Turn a raw uint8 byte string into a float image in [-1, 1]."""
        img = tf.io.decode_raw(raw, tf.uint8)
        img = tf.reshape(img, [self._im_size, self._im_size, 3])
        # Map pixel values from [0, 255] to [-1, 1].
        return tf.cast(img, tf.float32) * (1. / 255) * 2 - 1

    def _parse_example(self, example_proto):
        """Parse one serialized tf.train.Example into decoded tensors."""
        name_to_features = {
            "img1": tf.io.FixedLenFeature([], tf.string),
            "img2": tf.io.FixedLenFeature([], tf.string),
            "label1": tf.io.FixedLenFeature([], tf.int64),
            "label2": tf.io.FixedLenFeature([], tf.int64),
        }
        example = tf.io.parse_single_example(example_proto, name_to_features)
        return (
            self._decode_image(example["img1"]),
            self._decode_image(example["img2"]),
            example["label1"],
            example["label2"],
        )

    def _make_iterator(self, dataset, batch_size, num_epochs, shuffle_buffer):
        """Repeat, shuffle, parse and batch serialized records; return (iterator, batch)."""
        dataset = dataset.repeat(num_epochs)
        # Dataset transformations return a NEW dataset, so the result must be
        # reassigned. Shuffling serialized records before parsing is cheaper
        # than shuffling decoded images.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)
        dataset = dataset.map(self._parse_example)
        dataset = dataset.batch(batch_size=batch_size)
        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        return iterator, iterator.get_next()

    def get_batch(self, tfrecord_file, batch_size, num_epochs=None,
                  shuffle_buffer=10000):
        """Return ``(iterator, next_batch)`` reading ``tfrecord_file``.

        Args:
            tfrecord_file: TFRecord filename(s) for ``tf.data.TFRecordDataset``.
            batch_size: number of examples per batch.
            num_epochs: passes over the data; ``None`` repeats indefinitely.
            shuffle_buffer: size of the ``Dataset.shuffle`` buffer.
        """
        dataset = tf.data.TFRecordDataset(tfrecord_file)
        return self._make_iterator(dataset, batch_size, num_epochs, shuffle_buffer)

    def get_batch_shuffle(self, file_list, batch_size, num_epochs=None,
                          shuffle_buffer=1000, cycle_length=4):
        """Shuffle across many TFRecord files when one file is too large to shuffle.

        File order is shuffled by ``list_files``, records from ``cycle_length``
        files are interleaved, and a record-level shuffle buffer mixes them
        further.

        Args:
            file_list: TFRecord file names or glob pattern(s).
            batch_size: number of examples per batch.
            num_epochs: passes over the data; ``None`` repeats indefinitely.
            shuffle_buffer: size of the record-level ``Dataset.shuffle`` buffer.
            cycle_length: number of files read concurrently by ``interleave``.
                With 1, each file is read to the end before the next starts,
                so records from different files are never mixed.
        """
        files = tf.data.Dataset.list_files(file_list, shuffle=True)
        dataset = files.interleave(
            map_func=tf.data.TFRecordDataset, cycle_length=cycle_length)
        return self._make_iterator(dataset, batch_size, num_epochs, shuffle_buffer)