TensorFlow Data Loading with Dataset

Without further ado, let's dive right in.


The previous post gave a brief introduction to generating and loading the TFRecord data format; this post covers another data loading approach: Dataset.

Regardless of which loading approach you use, the workflow follows the same steps:

  1. Walk through every image in the raw dataset
  2. Read each image together with its class label
  3. Feed the resulting image tensors into the model for training

Here is how to use Dataset to load image data:

Prerequisite: the test dataset used in this post contains only 6 images, three cats and three dogs. The directory structure is as follows:

[Figure 1: directory structure of the test dataset]

1) Read the raw dataset information and save it to a txt file, one line per image, in the following format:

image path         image label (starting from 0 and increasing by 1, which makes the later one-hot encoding straightforward)
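With the directory layout above (assuming the class sub-folders are named cat and dog), the generated file would contain one tab-separated line per image, along these lines (the paths are purely illustrative):

/path/to/dataset/cat/cat_1.jpg    0
/path/to/dataset/dog/dog_1.jpg    1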

import os

current_path = os.path.dirname(os.path.abspath(__file__))
datasets_path = os.path.join(current_path, "dataset")
filename = os.path.join(current_path, "dataset_image_list.txt")

# only generate the list file if it does not exist yet
if not os.path.exists(filename):
    with open(filename, "w") as f_obj:
        # enumerate class sub-directories only (sorted), so labels stay consecutive: 0, 1, ...
        class_names = sorted(d for d in os.listdir(datasets_path)
                             if os.path.isdir(os.path.join(datasets_path, d)))
        for cls_index, cls_name in enumerate(class_names):
            print("#" * 40, cls_name, "#" * 40)
            for img_name in os.listdir(os.path.join(datasets_path, cls_name)):
                img_path = os.path.join(datasets_path, cls_name, img_name)
                if os.path.isfile(img_path):
                    f_obj.write(img_path + "\t" + str(cls_index) + "\n")
else:
    print("file exists")

One more benefit of first dumping the image information into a txt file is that it makes it easy to shuffle the whole dataset up front, whether you go on to generate TFRecords or a Dataset. Without it, images with the same label would be read back-to-back, and you would still have to shuffle again at training time.
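A minimal sketch of that idea, assuming the list file produced in step 1 above (step 2 below does the same thing with np.random.permutation): read every line once and shuffle globally before any conversion.

import random

with open("dataset_image_list.txt", "r") as f_obj:
    image_list = f_obj.readlines()

random.shuffle(image_list)  # one global shuffle across all classes
for line in image_list:
    img_path, img_class = line.split()
    # hand each (img_path, img_class) pair to the TFRecord / Dataset pipeline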

2) Load the txt file generated in step 1 and write the images it lists into TFRecord files

import os
import cv2
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np

current_path = os.path.dirname(os.path.abspath(__file__))
datasets_path = os.path.join(current_path, "dataset_image_list.txt")
train_tfreocrd_filename = os.path.join(current_path, "tfrecord_files", "cat_and_dog_train.tfrecords")
validation_tfreocrd_filename = os.path.join(current_path, "tfrecord_files", "cat_and_dog_validation.tfrecords")
image_size = 224
image_channel = 3


def generate_tfrecord():
    # skip regeneration if the TFRecord files already exist
    if os.path.exists(train_tfreocrd_filename):
        return

    # make sure the output directory exists before opening the writers
    os.makedirs(os.path.join(current_path, "tfrecord_files"), exist_ok=True)

    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    train_writer = tf.python_io.TFRecordWriter(path=train_tfreocrd_filename)
    validation_writer = tf.python_io.TFRecordWriter(path=validation_tfreocrd_filename)

    with open(datasets_path, "r") as f_obj:
        image_list = f_obj.readlines()
        image_list_len = len(image_list)
        print(image_list_len)
        permutation = np.random.permutation(image_list_len)
        print(permutation)

        img_list = []
        for i in permutation:
            img_list.append(image_list[i])

        for img_index, img_info in enumerate(img_list):
            img_path, img_class = img_info.split()
            img_class = int(img_class)
            print(img_path, img_class)

            if os.path.isfile(img_path):
                img = cv2.imread(filename=img_path)

                if img.ndim == image_channel:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    img = cv2.resize(img, (image_size, image_size), interpolation=cv2.INTER_AREA)
                    img_pixles = img.shape[0] * img.shape[1] * img.shape[2]

                    # plt.imshow(img)
                    # plt.show()
                else:
                    continue
                image_raw = img.tobytes()
                # print(image_raw)

                example = tf.train.Example(
                    features=tf.train.Features(
                        feature={
                            "pixels": _int64_feature(img_pixles),
                            "label": _int64_feature(img_class),
                            "image_raw": _bytes_feature(image_raw)
                        }
                    )
                )

                if img_index % 2 == 0:
                    validation_writer.write(example.SerializeToString())
                else:
                    train_writer.write(example.SerializeToString())
    train_writer.close()
    validation_writer.close()


def read_from_tfrecord(sess):
    reader = tf.TFRecordReader()

    filename_queue = tf.train.string_input_producer(
        string_tensor=[train_tfreocrd_filename, validation_tfreocrd_filename])

    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized_example,
        features={
            "pixels": tf.FixedLenFeature([], tf.int64),
            "label": tf.FixedLenFeature([], tf.int64),
            "image_raw": tf.FixedLenFeature([], tf.string)
        }
    )

    image = tf.decode_raw(features['image_raw'], tf.uint8)
    label = tf.cast(features["label"], tf.int64)
    pixels = tf.cast(features["pixels"], tf.int32)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(6):
        image_raw, image_label, _ = sess.run([image, label, pixels])
        image_raw = np.reshape(image_raw, (image_size, image_size, image_channel))
        print(image_raw.shape)
        plt.imshow(image_raw)
        plt.xlabel(str(image_label))
        plt.show()

    # stop the input queue threads cleanly before returning
    coord.request_stop()
    coord.join(threads)


if __name__ == '__main__':
    with tf.Session() as sess:
        generate_tfrecord()
        read_from_tfrecord(sess=sess)

3) Build a Dataset from the TFRecord files generated in step 2

import os
import tensorflow as tf


def _parse_tfrecord_features(serialized_feature):
    features = {
        "pixels": tf.FixedLenFeature([], tf.int64),
        "label": tf.FixedLenFeature([], tf.int64),
        "image_raw": tf.FixedLenFeature([], tf.string)
    }
    image_features = tf.parse_single_example(
        serialized_feature,
        features=features
    )
    image_raw = tf.decode_raw(bytes=image_features["image_raw"], out_type=tf.uint8)
    image_raw = tf.image.convert_image_dtype(image_raw, tf.float32)

    labels = tf.cast(image_features["label"], tf.int64)
    image_labels = tf.one_hot(labels, depth=2)

    return image_raw, image_labels

def _image_reshape(image_data, image_label):
    return tf.reshape(image_data, (224, 224, 3)), image_label

current_path = os.path.dirname(os.path.abspath(__file__))
train_tfrecords = os.path.join(current_path, "tfrecord_files", "cat_and_dog_train.tfrecords")
validation_tfrecords = os.path.join(current_path, "tfrecord_files", "cat_and_dog_validation.tfrecords")
repeat_count = 3
batch_size = 16
shuffle_buffer_size = 4096
train_epoch = 10
image_size = 224
image_channel = 3

train_tfrecord_datasets = tf.data.TFRecordDataset(filenames=train_tfrecords)
train_tfrecord_datasets = train_tfrecord_datasets.map(_parse_tfrecord_features)
train_tfrecord_datasets = train_tfrecord_datasets.repeat(count=repeat_count)
train_tfrecord_datasets = train_tfrecord_datasets.map(_image_reshape)

train_tfrecord_datasets = train_tfrecord_datasets.shuffle(buffer_size=shuffle_buffer_size,
                                                          reshuffle_each_iteration=True)
train_tfrecord_datasets_batch = train_tfrecord_datasets.batch(batch_size=batch_size)

validation_tfrecord_datasets = tf.data.TFRecordDataset(filenames=validation_tfrecords)
validation_tfrecord_datasets = validation_tfrecord_datasets.map(_parse_tfrecord_features)
validation_tfrecord_datasets = validation_tfrecord_datasets.map(_image_reshape)
validation_tfrecord_datasets = validation_tfrecord_datasets.shuffle(buffer_size=shuffle_buffer_size, reshuffle_each_iteration=True)
validation_tfrecord_datasets_batch = validation_tfrecord_datasets.batch(batch_size=batch_size)

iterator = tf.data.Iterator.from_structure(
    output_types=train_tfrecord_datasets_batch.output_types,
    output_shapes=train_tfrecord_datasets_batch.output_shapes
)

train_init_op = iterator.make_initializer(dataset=train_tfrecord_datasets_batch)
validation_init_op = iterator.make_initializer(dataset=validation_tfrecord_datasets_batch)

next_batch = iterator.get_next()

if __name__ == '__main__':
    with tf.Session() as sess:
        for epoch in range(train_epoch):
            sess.run(train_init_op)
            print("#" * 30, "training", "#" * 30)
            while True:
                try:
                    train_data_batch, train_label_batch = sess.run(next_batch)
                    print(train_label_batch)
                except tf.errors.OutOfRangeError as e:
                    # print(e)
                    break

            print("#" * 30, "validation", "#" * 30)
            sess.run(validation_init_op)
            while True:
                try:
                    validation_data_batch, validation_label_batch = sess.run(next_batch)
                    print(validation_label_batch)
                except tf.errors.OutOfRangeError as e:
                    # print(e)
                    break

A few points about this code deserve explanation:

1) _parse_tfrecord_features parses the image information stored in the TFRecord file: it decodes the raw image matrix and converts the stored label into its one-hot encoding.
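For reference, with depth=2 the one-hot encoding maps label 0 to [1, 0] and label 1 to [0, 1]; a tiny standalone check in the same TF 1.x style as the code above:

import tensorflow as tf

with tf.Session() as sess:
    # prints [[1. 0.] [0. 1.]]
    print(sess.run(tf.one_hot([0, 1], depth=2)))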

2) _image_reshape restores the parsed image matrix to its 3-D shape of (224, 224, 3).

3) train_tfrecord_datasets.map(_parse_tfrecord_features) and train_tfrecord_datasets.map(_image_reshape) apply _parse_tfrecord_features and _image_reshape, respectively, to every element of the dataset, producing a Dataset that matches the input format expected by the network.
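As a side note, the two map calls can also be fused into a single one and, on TF 1.4 or later, parallelized with num_parallel_calls. A minimal sketch of that variant (not used in the code above; _parse_and_reshape is a hypothetical helper):

def _parse_and_reshape(serialized_feature):
    # parse and reshape in a single pass over the dataset
    image, label = _parse_tfrecord_features(serialized_feature)
    return tf.reshape(image, (image_size, image_size, image_channel)), label

train_tfrecord_datasets = tf.data.TFRecordDataset(filenames=train_tfrecords)
train_tfrecord_datasets = train_tfrecord_datasets.map(_parse_and_reshape, num_parallel_calls=4)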

4) Because the training data and validation data share the same structure, tf.data.Iterator.from_structure is used to construct the iterator object. Note that it is built from train_tfrecord_datasets_batch:

iterator = tf.data.Iterator.from_structure(
    output_types=train_tfrecord_datasets_batch.output_types,
    output_shapes=train_tfrecord_datasets_batch.output_shapes
)

5) The following two lines create the initialization ops for the training Dataset and the validation Dataset:

train_init_op = iterator.make_initializer(dataset=train_tfrecord_datasets_batch)
validation_init_op = iterator.make_initializer(dataset=validation_tfrecord_datasets_batch)

6) When you need training data, first run sess.run(train_init_op) to initialize the iterator, then run train_data_batch, train_label_batch = sess.run(next_batch) to get a training batch. Likewise, when you need validation data, first run sess.run(validation_init_op), then run validation_data_batch, validation_label_batch = sess.run(next_batch) to get a validation batch.
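For completeness, the tensors returned by the iterator can also be wired directly into the model graph instead of being fetched with sess.run and fed back in. A rough sketch under that assumption, where the single dense layer is only a hypothetical stand-in for a real CNN:

# sketch only: a stand-in classifier built directly on the iterator output
images, labels = next_batch                     # graph tensors, not numpy arrays
flat = tf.layers.flatten(images)                # shape (batch, 224 * 224 * 3)
logits = tf.layers.dense(flat, units=2)         # 2 classes: cat / dog
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(train_epoch):
        sess.run(train_init_op)
        while True:
            try:
                _, loss_value = sess.run([train_op, loss])
            except tf.errors.OutOfRangeError:
                break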

Note: as the iterator keeps fetching the next element, a tf.errors.OutOfRangeError is raised once the data is exhausted, which means the current pass over the data is finished. To start another pass, run the corresponding initialization op (train_init_op or validation_init_op) again.
