自用
"""Converts PASCAL dataset to TFRecords file format.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import argparse import io import os import sys import natsort import PIL.Image import tensorflow as tf from utils import dataset_util parser = argparse.ArgumentParser() parser.add_argument('--data_dir', type=str, default='/home/a/dataset/cityscapes/', help='Path to the directory containing the PASCAL VOC data.') parser.add_argument('--output_path', type=str, default='./dataset', help='Path to the directory to create TFRecords outputs.') parser.add_argument('--train_data_list', type=str, default='./dataset/train.txt', help='Path to the file listing the training data.') parser.add_argument('--valid_data_list', type=str, default='./dataset/val.txt', help='Path to the file listing the validation data.') parser.add_argument('--image_data_dir', type=str, default='leftImg8bit', help='The directory containing the image data.') parser.add_argument('--label_data_dir', type=str, default='gtFine', help='The directory containing the augmented label data.') def dict_to_tf_example(image_path, label_path): """Convert image and label to tf.Example proto. Args: image_path: Path to a single PASCAL image. label_path: Path to its corresponding label. Returns: example: The converted tf.Example. Raises: ValueError: if the image pointed to by image_path is not a valid JPEG or if the label pointed to by label_path is not a valid PNG or if the size of image does not match with that of label. """ with tf.gfile.GFile(image_path, 'rb') as fid: encoded_jpg = fid.read() encoded_jpg_io = io.BytesIO(encoded_jpg) image = PIL.Image.open(encoded_jpg_io) if image.format != 'PNG': raise ValueError('Image format not PNG') with tf.gfile.GFile(label_path, 'rb') as fid: encoded_label = fid.read() encoded_label_io = io.BytesIO(encoded_label) label = PIL.Image.open(encoded_label_io) if label.format != 'PNG': raise ValueError('Label format not PNG') if image.size != label.size: raise ValueError('The size of image does not match with that of label.') width, height = image.size example = tf.train.Example(features=tf.train.Features(feature={ 'image/height': dataset_util.int64_feature(height), 'image/width': dataset_util.int64_feature(width), 'image/encoded': dataset_util.bytes_feature(encoded_jpg), 'image/format': dataset_util.bytes_feature('png'.encode('utf8')), 'label/encoded': dataset_util.bytes_feature(encoded_label), 'label/format': dataset_util.bytes_feature('png'.encode('utf8')), })) return example def scanDir_img_File(dir): for root, dirs, files in os.walk(dir, True, None, False): # 遍列目录 for f in files: yield os.path.join(root,f) def scanDir_lable_File(dir): for root, dirs, files in os.walk(dir, True, None, False): # 遍列目录 # 处理该文件夹下所有文件: for f in files: if os.path.isfile(os.path.join(root, f)): a = os.path.splitext(f) lable = a[0].split('_')[4] # print(lable) if lable in ('labelTrainIds'): # print(os.path.join(root,f)) yield os.path.join(root,f) def create_tf_record(output_filename, image_dir, label_dir): """Creates a TFRecord file from examples. Args: output_filename: Path to where output file is saved. image_dir: Directory where image files are stored. label_dir: Directory where label files are stored. """ imgg = [] writer = tf.python_io.TFRecordWriter(output_filename) img = scanDir_img_File(image_dir) for imgs in img: imgg.append(imgs) image_list = natsort.natsorted(imgg) lable = scanDir_lable_File(label_dir) lablee = [] for lables in lable: lablee.append(lables) label_list = natsort.natsorted(lablee) for image_path,label_path in zip(image_list,label_list): print(image_path,label_path) try: tf_example = dict_to_tf_example(image_path, label_path) writer.write(tf_example.SerializeToString()) except ValueError: tf.logging.warning('Invalid example: %s, ignoring.') writer.close() def main(unused_argv): if not os.path.exists(FLAGS.output_path): os.makedirs(FLAGS.output_path) tf.logging.info("Reading from CITY dataset") train_image_dir = os.path.join(FLAGS.data_dir, FLAGS.image_data_dir,'train') train_label_dir = os.path.join(FLAGS.data_dir, FLAGS.label_data_dir,'train') val_image_dir = os.path.join(FLAGS.data_dir, FLAGS.image_data_dir, 'val') val_label_dir = os.path.join(FLAGS.data_dir, FLAGS.label_data_dir, 'val') train_output_path = os.path.join(FLAGS.output_path, 'city_train.record') val_output_path = os.path.join(FLAGS.output_path, 'city_val.record') create_tf_record(train_output_path, train_image_dir, train_label_dir) create_tf_record(val_output_path, val_image_dir, val_label_dir) if __name__ == '__main__': tf.logging.set_verbosity(tf.logging.INFO) FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)