TFRecords file creation and reading

We use classification data as an example.

TFRecords file creation

TFRecords store the raw image data and labels in binary format. Each example is stored in the following form:

example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
"height": tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0]])),
"width": tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[1]])),
"channel": tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[2]])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
})) 
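
To keep the Example construction readable, the Feature constructors can be wrapped in small helpers. A minimal sketch (the helper names _int64_feature and _bytes_feature are illustrative, not part of the original code):

def _int64_feature(value):
    # wrap a Python int as an int64 Feature
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    # wrap raw bytes as a bytes Feature
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


example = tf.train.Example(features=tf.train.Features(feature={
    "label": _int64_feature(label),
    "height": _int64_feature(shape[0]),
    "width": _int64_feature(shape[1]),
    "channel": _int64_feature(shape[2]),
    'img_raw': _bytes_feature(img_raw)
}))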

The code for writing the file is as follows:

"""
2018.7.20 --
Use this code to create the dataset for classification.
first write your own class_dic according to the folder.
then set the root path for all class folders in "data_path"
:parameter classes: the name of folder for different classes' images
:parameter writer: the filename of tfrecord
"""


import os
import tensorflow as tf
from PIL import Image
import numpy as np


cwd = os.getcwd()

class_dic = {
    "airplane": 0,
    "ship": 1,
    "background": 2
}


def tf_record_create(cwd, classes, writer):
    """

    :param cwd: directory that contains all class folders (one class, one folder)
    :param classes: ['class_name1', 'class_name2', ...]
    :param writer: tf.python_io.TFRecordWriter("*****.tfrecords")
    :return:
    """
    for index, name in enumerate(classes):
        label = class_dic[name]
        class_path = cwd+'/'+name+'/'
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name  # path of each image

            img_pil = Image.open(img_path)
            img = np.array(img_pil)
            # img = img.resize((IMG_HEIGHT, IMG_WIDTH))
            # instead of resize, get the image shape and write in the example
            shape = list(img.shape)
            if len(shape) == 2:
                shape.append(1)
            shape = np.array(shape, np.int64)
            img_raw = img.tobytes()  # convert the image to raw bytes
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
                "height": tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0]])),
                "width": tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[1]])),
                "channel": tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[2]])),
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            }))  # the Example wraps the label and image data
            writer.write(example.SerializeToString())  # serialize to a string and write it

    writer.close()
    # print("image to tfrecord processed")
    return


if __name__ == '__main__':
    # training data creation
    data_path = "../DATA/train"
    tf_record_create(cwd=data_path,
                     classes=['airplane', 'ship', 'background'],
                     writer=tf.python_io.TFRecordWriter("xingtu_cls_4360train.tfrecords"))
    print("tfrecord creation finished")

TFRecords file reading

When a TFRecords file is read, its contents are parsed according to the same feature dictionary that was used when it was written:

features = tf.parse_single_example(serialized_example,
                                features={
                                    'label': tf.FixedLenFeature([], tf.int64),
                                    'height': tf.FixedLenFeature([], tf.int64),
                                    'width': tf.FixedLenFeature([], tf.int64),
                                    'channel': tf.FixedLenFeature([], tf.int64),
                                    'img_raw' : tf.FixedLenFeature([], tf.string),
                                })

The full reading code is as follows:

"""
2018.7.20 --
Use this code to create the dataset for classification.
first write your own class_dic according to the folder.
then set the root path for all class folders in "data_path"
:parameter classes: the name of folder for different classes' images
:parameter writer: the filename of tfrecord
"""


import os
import tensorflow as tf
from PIL import Image
import numpy as np


cwd = os.getcwd()

class_dic = {
    "airplane": 0,
    "ship": 1,
    "background": 2
}


def tf_record_read_and_save(filepath):
    filename_queue = tf.train.string_input_producer([filepath])  # enqueue the filename into the input queue
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   # returns the filename and the serialized example

    features = tf.parse_single_example(serialized_example,
                                    features={
                                        'label': tf.FixedLenFeature([], tf.int64),
                                        'height': tf.FixedLenFeature([], tf.int64),
                                        'width': tf.FixedLenFeature([], tf.int64),
                                        'channel': tf.FixedLenFeature([], tf.int64),
                                        'img_raw' : tf.FixedLenFeature([], tf.string),
                                    })  # parse the image data and label out of the example

    h = tf.cast(features['height'], tf.int32)
    w = tf.cast(features['width'], tf.int32)
    c = tf.cast(features['channel'], tf.int32)

    image = tf.decode_raw(features['img_raw'], tf.uint8)
    image = tf.reshape(image, [h, w, c])

    label = tf.cast(features['label'], tf.int32)
    label = tf.reshape(label, [1])

    with tf.Session() as sess:  # start a session
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        for i in range(1):
            example, l = sess.run([image, label])  # fetch the image and label in the session

            h_val, w_val, c_val = sess.run([h, w, c])
            print("(h, w, c) = {}, {}, {}".format(h_val, w_val, c_val))
            print("image's shape:", example.shape)

            img = np.array(np.squeeze(example))
            img = Image.fromarray(img, 'RGB' if example.shape[2] == 3 else 'L')
            img.save(str(i) + '_Label_' + str(l[0]) + '.jpg')  # save the image to disk
            # print(example, l)
        coord.request_stop()
        coord.join(threads)


if __name__ == '__main__':
    tf_record_read_and_save(filepath="xingtu_cls_4360train.tfrecords")

Note: the code for TFRecords creation and the read-back check is cls_tfrecord_create.py.

Reading data from a TFRecords file and building data batches

The code is /data_made/tfrecord_read_and_show/demo_tfrecord_read_c1.py.

There are two ways to read the tfrecord dataset: 1. next_batch  2. the tf.data Dataset API.

  1. next_batch:
    Use tf.train.string_input_producer(), read_single_example_and_decode(), preprocess_image() and tf.train.batch() to create the (image, label) tensors.

  2. dataset:
    Use tf.data.TFRecordDataset(), dataset = dataset.map(_parser) and the other dataset methods (prefetch, shuffle, repeat, batch) to create the (image, label) tensors.

The code is as follows:

import tensorflow as tf
import os
import numpy as np
import matplotlib.pyplot as plt

_HEIGHT = 256
_WIDTH = 256
_CHANNELS = 1

# preprocessing parameters
random_extend_ratio = 1.2
random_contrast_lower = 0.3
random_contrast_upper = 1.0
random_brightness_max_delta = 0.5


def preprocess_image(image, is_training):
    if is_training:
        image = tf.image.resize_images(images=image,
                                    size=[tf.cast(_HEIGHT * random_extend_ratio, tf.int32),
                                            tf.cast(_WIDTH * random_extend_ratio, tf.int32)])

        image = tf.random_crop(image, [_HEIGHT, _WIDTH, _CHANNELS])

        # flip
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)

        # adjust hue, contrast, saturation, brightness (hue and saturation are not supported for a one-channel gray image)
        image = tf.image.random_contrast(image, lower=random_contrast_lower, upper=random_contrast_upper)
        image = tf.image.random_brightness(image, max_delta=random_brightness_max_delta)

    else:
        # According to a quick test, resize_images() and resize_area() resize in the same way and do not show the
        # problem mentioned in the blog post (), while resize_bicubic() has a side effect when align_corners is set to True.
        image = tf.image.resize_images(images=image,
                                    size=[_HEIGHT, _WIDTH])

    image = tf.image.per_image_standardization(image)
    return image


def read_single_example_and_decode(filename_queue):

    # reader = tf.TFRecordReader(options=tfrecord_options)
    reader = tf.TFRecordReader()

    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized=serialized_example,
        features = {
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'channel': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        }
    )
    img_height = tf.cast(features['height'], tf.int32)
    img_width = tf.cast(features['width'], tf.int32)
    img_channel = tf.cast(features['channel'], tf.int32)

    img = tf.decode_raw(features['img_raw'], tf.uint8)

    img = tf.reshape(img, shape=[img_height, img_width, img_channel])

    label = tf.cast(features['label'], tf.int32)

    return img, label


def next_batch(dataset_name, batch_size, is_training):
    if dataset_name == "xingtu":
        pattern = "../tfrecords/xingtu_cls_63test.tfrecords"
    else:
        raise ValueError("xingtu only")
    print('tfrecord path is -->', os.path.abspath(pattern))

    # filename_tensorlist = tf.train.match_filenames_once(pattern)
    filename_queue = tf.train.string_input_producer([pattern])

    image, label = read_single_example_and_decode(filename_queue)

    image = preprocess_image(image, is_training)

    img_batch, label_batch = tf.train.batch([image, label],
                                            batch_size=batch_size,
                                            capacity=32,
                                            num_threads=4,
                                            dynamic_pad=True)
    return img_batch, label_batch
    # return image, label


# obtain the mask for seg
def input_fn(filename, is_training, batch_size, shuffle_buffer, num_epochs=1):
    if not os.path.exists(filename):
        raise ValueError("no such file exists: " + filename)

    def _parser(example_proto):
        features = {
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'channel': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        }
        parsed_features = tf.parse_single_example(example_proto, features=features)

        height = tf.cast(parsed_features['height'], tf.int32)
        width = tf.cast(parsed_features['width'], tf.int32)
        c = tf.cast(parsed_features['channel'], tf.int32)

        image = tf.decode_raw(parsed_features['img_raw'], tf.uint8)
        image = tf.reshape(image, [height, width, c])
        image = preprocess_image(image, is_training)

        label = tf.cast(parsed_features['label'], tf.int32)

        return image, label

    dataset = tf.data.TFRecordDataset(filename)
    dataset = dataset.prefetch(buffer_size=batch_size)
    if is_training:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.map(_parser)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    next_image, next_label = iterator.get_next()

    return next_image, next_label


if __name__ == '__main__':
    image, label = next_batch(dataset_name="xingtu", batch_size=1, is_training=True)
    # image, label = input_fn("./data/one_image.tfrecords",
    #                         is_training=False, batch_size=1, shuffle_buffer=1, num_epochs=1)

    tf.summary.image("image", image)
    summary_op = tf.summary.merge_all()

    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    global_step = tf.train.get_or_create_global_step()
    with tf.Session() as sess:
        writer = tf.summary.FileWriter("./sar_summary", sess.graph)
        sess.run(init)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        img_ = sess.run(image)
        # per_image_standardization produces zero-mean floats, so the uint8 cast below is only a rough visual check
        img_show = np.array(np.squeeze(img_), dtype=np.uint8)

        plt.figure()
        plt.imshow(img_show)
        plt.show()

        summary = sess.run(summary_op)
        writer.add_summary(summary, 0)
        coord.request_stop()
        coord.join(threads)
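
For the dataset path, the one-shot iterator returned by input_fn does not need a Coordinator or queue runners; plain sess.run calls are enough. A minimal sketch, assuming the same test tfrecord used by next_batch above and that all decoded images share the same channel count:

# Consume a few batches from the tf.data pipeline (no queue runners needed).
image, label = input_fn("../tfrecords/xingtu_cls_63test.tfrecords",
                        is_training=False, batch_size=4, shuffle_buffer=1, num_epochs=1)
with tf.Session() as sess:
    sess.run(tf.group(tf.global_variables_initializer(),
                      tf.local_variables_initializer()))
    try:
        for _ in range(3):
            img_batch, label_batch = sess.run([image, label])
            print("batch shape:", img_batch.shape, "labels:", label_batch)
    except tf.errors.OutOfRangeError:
        print("dataset exhausted")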
