TensorFlow —— tf.data.Dataset 读取数据代码

TensorFlow —— tf.data.Dataset 读取数据代码

  • 下面是使用 tf.data.Dataset 读取通用数据（图片文件路径 + 标签）的基础代码。
import tensorflow as tf


class TfDataDataset(object):
    """Read paired-image samples from image file paths with tf.data.Dataset.

    Each sample is ``(img1_path, img2_path, label1, label2)``. Both paths are
    read, decoded as JPEG, resized to ``im_size`` x ``im_size`` x 3 and scaled
    to the range [-1, 1]. The pipeline is built for TF1-style graph mode
    (initializable iterator + ``Session``).
    """

    def __init__(self, im_size):
        """
        Args:
            im_size: side length in pixels of the square images produced.
        """
        self._im_size = im_size

    def train(self, data, batch_size, num_epochs=1):
        """Build the pipeline and fetch a single batch in a graph-mode session.

        Args:
            data: 4-sequence ``(img1_paths, img2_paths, labels1, labels2)``
                accepted by ``tf.data.Dataset.from_tensor_slices``.
            batch_size: number of samples per batch.
            num_epochs: number of passes over ``data``; ``None`` repeats
                indefinitely.

        Returns:
            Tuple ``(img1, img2, label1, label2)`` holding the first batch.
        """
        iterator, batch = self.get_batch(data, batch_size, num_epochs)
        # Session must be instantiated: the class object itself is not a
        # context manager.
        with tf.compat.v1.Session() as sess:
            sess.run(iterator.initializer)
            img1, img2, label1, label2 = sess.run(batch)
        return img1, img2, label1, label2

    def get_batch(self, data, batch_size, num_epochs=None,
                  shuffle_buffer_size=10000):
        """Build a shuffled, repeated, batched input pipeline.

        Args:
            data: 4-sequence ``(img1_paths, img2_paths, labels1, labels2)``.
            batch_size: number of samples per batch.
            num_epochs: number of passes over the data; ``None`` (default)
                repeats indefinitely, as ``Dataset.repeat(None)`` does.
            shuffle_buffer_size: number of elements kept in the shuffle buffer.

        Returns:
            ``(iterator, batch)``. Run ``iterator.initializer`` before
            evaluating ``batch``.
        """
        def read_image(filename):
            """Load one JPEG file as a float32 [-1, 1] image of shape (S, S, 3)."""
            image_value = tf.io.read_file(filename)
            # channels=3 forces 3-channel output so grayscale JPEGs still
            # match the static [S, S, 3] shape declared below.
            img = tf.image.decode_jpeg(image_value, channels=3)
            image_resize = tf.compat.v1.image.resize_images(
                img, [self._im_size, self._im_size])
            image_resize.set_shape([self._im_size, self._im_size, 3])
            # Map pixel values from [0, 255] to [-1, 1].
            image_resize = tf.cast(image_resize, dtype=tf.float32) * (1. / 255) * 2 - 1
            return image_resize

        def preprocess(img1, img2, label1, label2):
            """Replace both image paths with decoded images; labels pass through."""
            return read_image(img1), read_image(img2), label1, label2

        img1, img2, label1, label2 = data
        dataset = tf.data.Dataset.from_tensor_slices((img1, img2, label1, label2))
        # shuffle() returns a new dataset; it must be reassigned or it has no
        # effect. Shuffling paths (before map) keeps the buffer cheap, and
        # shuffling before repeat keeps epoch boundaries: every sample is seen
        # once per epoch, in a fresh order each epoch.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
        dataset = dataset.repeat(num_epochs)
        dataset = dataset.map(preprocess)
        dataset = dataset.batch(batch_size=batch_size)
        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        batch = iterator.get_next()
        return iterator, batch
  • 下面是使用 tf.data.Dataset 读取 TFRecord 数据的基础代码，以及对多个 TFRecord 文件进行完全打乱（shuffle）读取的代码。
class TfDataDatasetTfrecord(object):
    """Read paired-image samples from TFRecord files with tf.data.Dataset.

    Each record is parsed with features ``img1``/``img2`` (raw uint8 bytes)
    and ``label1``/``label2`` (int64). Images are reshaped to
    ``im_size`` x ``im_size`` x 3 and scaled to [-1, 1]. The pipeline is built
    for TF1-style graph mode (initializable iterator + ``Session``).
    """

    def __init__(self, im_size):
        """
        Args:
            im_size: side length in pixels of the square images stored in the
                records.
        """
        self._im_size = im_size

    def train(self, data, batch_size, num_epochs=1):
        """Build the pipeline over ``data`` and fetch a single batch.

        Args:
            data: TFRecord file path(s) passed to ``get_batch``.
            batch_size: number of samples per batch.
            num_epochs: number of passes over the data; ``None`` repeats
                indefinitely.

        Returns:
            Tuple ``(img1, img2, label1, label2)`` holding the first batch.
        """
        iterator, batch = self.get_batch(data, batch_size, num_epochs)
        # Session must be instantiated: the class object itself is not a
        # context manager.
        with tf.compat.v1.Session() as sess:
            sess.run(iterator.initializer)
            img1, img2, label1, label2 = sess.run(batch)
        return img1, img2, label1, label2

    def _decode_image(self, raw_bytes):
        """Turn one raw uint8 byte string into a float32 [-1, 1] image."""
        img = tf.io.decode_raw(raw_bytes, tf.uint8)
        # Assumes each record stores exactly im_size * im_size * 3 bytes;
        # the reshape fails at run time otherwise.
        img = tf.reshape(img, [self._im_size, self._im_size, 3])
        # Map pixel values from [0, 255] to [-1, 1].
        return tf.cast(img, tf.float32) * (1. / 255) * 2 - 1

    def _parse_example(self, example_proto):
        """Parse one serialized record into ``(img1, img2, label1, label2)``."""
        name_to_features = {
            "img1": tf.io.FixedLenFeature([], tf.string),
            "img2": tf.io.FixedLenFeature([], tf.string),
            "label1": tf.io.FixedLenFeature([], tf.int64),
            "label2": tf.io.FixedLenFeature([], tf.int64),
        }
        example = tf.io.parse_single_example(example_proto, name_to_features)
        img1 = self._decode_image(example['img1'])
        img2 = self._decode_image(example['img2'])
        return img1, img2, example['label1'], example['label2']

    def _make_iterator(self, dataset, batch_size, num_epochs, shuffle_buffer_size):
        """Shuffle, repeat, parse and batch ``dataset``; return (iterator, batch)."""
        # shuffle() returns a new dataset; it must be reassigned or it has no
        # effect. Shuffling serialized records before parsing keeps the buffer
        # cheap, and shuffling before repeat keeps epoch boundaries.
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
        dataset = dataset.repeat(num_epochs)
        dataset = dataset.map(self._parse_example)
        dataset = dataset.batch(batch_size=batch_size)
        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        batch = iterator.get_next()
        return iterator, batch

    def get_batch(self, tfrecord_file, batch_size, num_epochs=None,
                  shuffle_buffer_size=10000):
        """Build a batched pipeline over one TFRecord file (or list of files).

        Args:
            tfrecord_file: path(s) passed to ``TFRecordDataset``.
            batch_size: number of samples per batch.
            num_epochs: number of passes over the data; ``None`` (default)
                repeats indefinitely.
            shuffle_buffer_size: number of records kept in the shuffle buffer.

        Returns:
            ``(iterator, batch)``. Run ``iterator.initializer`` before
            evaluating ``batch``.
        """
        dataset = tf.compat.v1.data.TFRecordDataset(tfrecord_file)
        return self._make_iterator(dataset, batch_size, num_epochs,
                                   shuffle_buffer_size)

    def get_batch_shuffle(self, file_list, batch_size, num_epochs=None,
                          shuffle_buffer_size=1000, cycle_length=1):
        """Build a pipeline that shuffles across many TFRecord files.

        Use this when the data is too large for one in-memory shuffle buffer.
        File order is shuffled by ``list_files``. Records are then shuffled
        through a buffer.

        Args:
            file_list: file paths or glob pattern(s) of TFRecord files.
            batch_size: number of samples per batch.
            num_epochs: number of passes over the data; ``None`` (default)
                repeats indefinitely.
            shuffle_buffer_size: number of records kept in the shuffle buffer.
            cycle_length: number of files read concurrently by ``interleave``.
                With 1 (default) files are read one after another. Larger
                values mix records from several files before the shuffle
                buffer, which gives a more thorough shuffle.

        Returns:
            ``(iterator, batch)``. Run ``iterator.initializer`` before
            evaluating ``batch``.
        """
        files = tf.data.Dataset.list_files(file_list, shuffle=True)
        dataset = files.interleave(map_func=tf.data.TFRecordDataset,
                                   cycle_length=cycle_length)
        return self._make_iterator(dataset, batch_size, num_epochs,
                                   shuffle_buffer_size)

你可能感兴趣的：(tensorflow, 深度学习, dataset)