VOC07+12合并训练用于图像分类

该程序可以用于图像分类,比如VOC2012的数据集,首先将数据集中标签和图片按照类别分为20类,然后再使用该程序将分好类的文件转换为tfrecord格式,用于训练图像分类模型。具体程序如下:

# -*- coding: utf-8 -*-
"""
Created on Sat Mar  9 13:22:18 2019

import os
import sys
import random
import numpy as np
import tensorflow as tf

import xml.etree.ElementTree as ET
import six
from six.moves import cPickle

sys.path.append(r".....\datasets") 
sys.path.append(r".....\utils") 
#import label_map_util
#import dataset_utils

# Sub-folders of a VOC<year> directory holding the XML annotations and the
# JPEG images respectively.
DIRECTORY_ANNOTATIONS = 'Annotations/'
DIRECTORY_IMAGES = 'JPEGImages/'
# Seed used before shuffling the split list, so conversions are reproducible.
RANDOM_SEED = 4242
# Maximum number of examples written into one .tfrecord shard.
SAMPLES_PER_FILES = 2000

# The 20 Pascal VOC classes in canonical order, each with a coarse
# super-category. Integer ids are assigned 1..20 in this order; id 0 is not
# given to any class.
_VOC_CLASSES = (
    ('aeroplane', 'Vehicle'),
    ('bicycle', 'Vehicle'),
    ('bird', 'Animal'),
    ('boat', 'Vehicle'),
    ('bottle', 'Indoor'),
    ('bus', 'Vehicle'),
    ('car', 'Vehicle'),
    ('cat', 'Animal'),
    ('chair', 'Indoor'),
    ('cow', 'Animal'),
    ('diningtable', 'Indoor'),
    ('dog', 'Animal'),
    ('horse', 'Animal'),
    ('motorbike', 'Vehicle'),
    ('person', 'Person'),
    ('pottedplant', 'Indoor'),
    ('sheep', 'Animal'),
    ('sofa', 'Indoor'),
    ('train', 'Vehicle'),
    ('tvmonitor', 'Indoor'),
)

# class name -> (integer label id, super-category), in canonical order.
VOC_LABELS = {
    class_name: (class_id, category)
    for class_id, (class_name, category) in enumerate(_VOC_CLASSES, start=1)
}

def int64_feature(values):
    """Wrap one integer, or a list/tuple of integers, in a TF-Feature.

    Args:
      values: a scalar, or a list/tuple of scalars.

    Returns:
      A ``tf.train.Feature`` whose ``int64_list`` holds the value(s).
    """
    as_sequence = values if isinstance(values, (tuple, list)) else [values]
    int_list = tf.train.Int64List(value=as_sequence)
    return tf.train.Feature(int64_list=int_list)

def float_feature(value):
    """Wrap one float, or a list/tuple of floats, in a TF-Feature.

    Args:
      value: a scalar, or a list/tuple of scalars.

    Returns:
      A ``tf.train.Feature`` whose ``float_list`` holds the value(s).
    """
    # Accept tuples as well as lists, like int64_feature does; otherwise a
    # tuple would be wrapped as one (invalid) element of the FloatList.
    if not isinstance(value, (tuple, list)):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

def bytes_feature(value):
    """Wrap one bytes value, or a list/tuple of them, in a TF-Feature.

    Args:
      value: a bytes scalar, or a list/tuple of bytes.

    Returns:
      A ``tf.train.Feature`` whose ``bytes_list`` holds the value(s).
    """
    # Accept tuples as well as lists, like int64_feature does; otherwise a
    # tuple would be wrapped as one (invalid) element of the BytesList.
    if not isinstance(value, (tuple, list)):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def _bytes_feature(value):
    """Wrap a single bytes-or-text value in a TF-Feature.

    On Python 3 a text (``str``) value is UTF-8 encoded first; any other
    value is stored unchanged as the single element of the BytesList.
    """
    is_text_on_py3 = six.PY3 and isinstance(value, six.text_type)
    payload = six.binary_type(value, encoding='utf-8') if is_text_on_py3 else value
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))

# Known VOC split names (ImageSets/Main/<split>.txt); not referenced elsewhere in this script.
SPLIT_MAP = 'train val trainval test'.split()
    
"""
Process a image and annotation file.

Args:
    filename:       string, path to an image file e.g., '/path/to/example.JPG'.
    coder:          instance of ImageCoder to provide TensorFlow image coding utils.

Returns:
    image_buffer:   string, JPEG encoding of RGB image.
    height:         integer, image height in pixels.
    width:          integer, image width in pixels.
读取一个样本图片及对应信息
"""
def _process_image(directory, name):
    """Read one VOC sample: the raw JPEG bytes plus its parsed XML annotation.

    Args:
      directory: the ``VOC<year>`` folder containing ``JPEGImages/`` and
        ``Annotations/``.
      name: image id without extension (one line of an ImageSets split file).

    Returns:
      A tuple ``(image_data, shape, bboxes, labels, labels_text, label_id,
      poses, name)``:
        image_data  -- raw bytes of ``<directory>/JPEGImages/<name>.jpg``;
        shape       -- ``[height, width, depth]`` from the XML ``<size>``;
        bboxes      -- one ``(ymin, xmin, ymax, xmax)`` tuple per object,
                       y divided by height and x divided by width;
        labels      -- integer class id per object, from ``VOC_LABELS``;
        labels_text -- ASCII-encoded class name per object;
        label_id    -- image-level label: the class id of the LAST object;
        poses       -- ASCII-encoded ``<pose>`` text per object;
        name        -- the ``name`` argument, echoed back so the caller can
                       store it as the record's filename.

    Raises:
      ValueError: if the annotation has no ``<object>``, so no image-level
        label can be derived.
      KeyError: if an object's class name is not a key of ``VOC_LABELS``.
    """
    # Read the encoded image; the context manager closes the file handle.
    image_path = os.path.join(directory, DIRECTORY_IMAGES, name + '.jpg')
    with tf.gfile.FastGFile(image_path, 'rb') as image_file:
        image_data = image_file.read()

    # Parse the matching XML annotation.
    xml_path = os.path.join(directory, DIRECTORY_ANNOTATIONS, name + '.xml')
    root = ET.parse(xml_path).getroot()

    size = root.find('size')
    shape = [int(size.find('height').text),
             int(size.find('width').text),
             int(size.find('depth').text)]

    bboxes = []
    labels = []
    labels_text = []
    poses = []
    for obj in root.findall('object'):
        label = obj.find('name').text
        labels.append(int(VOC_LABELS[label][0]))
        labels_text.append(label.encode('ascii'))

        bbox = obj.find('bndbox')
        # Coordinates are normalized to [0, 1] by the image size.
        bboxes.append((float(bbox.find('ymin').text) / shape[0],
                       float(bbox.find('xmin').text) / shape[1],
                       float(bbox.find('ymax').text) / shape[0],
                       float(bbox.find('xmax').text) / shape[1]))
        poses.append(obj.find('pose').text.encode('ascii'))

    if not labels:
        raise ValueError('No <object> found in annotation %s' % xml_path)
    # Image-level classification label: the id of the last object, as in the
    # original loop. NOTE(review): arbitrary for images with several classes.
    label_id = labels[-1]

    return (image_data, shape, bboxes, labels, labels_text, label_id,
            poses, name)
         

"""
Build an Example proto for an image example.

Args:
  image_data: string, JPEG encoding of RGB image;
  labels: list of integers, identifier for the ground truth;
  labels_text: list of strings, human-readable labels;
  bboxes: list of bounding boxes; each box is a list of integers;
      specifying [xmin, ymin, xmax, ymax]. All boxes are assumed to belong
      to the same label as the image label.
  shape: 3 integers, image shapes in pixels.
Returns:
  Example proto
将一个图片及对应信息按格式转换成训练时可读取的一个样本
"""
def _convert_to_example(image_data, labels, labels_text, bboxes, shape, label_id, poses, name):
    """Pack one sample's data into a ``tf.train.Example``.

    Args:
      image_data: encoded JPEG bytes.
      labels: integer class id per object.
      labels_text: encoded class name per object (bytes).
      bboxes: per-object ``(ymin, xmin, ymax, xmax)`` 4-sequences, as
        produced by ``_process_image``.
      shape: ``[height, width, depth]``; only height and width are stored.
      label_id: single image-level class id, stored as ``image/label``.
      poses: encoded pose string per object (bytes).
      name: image id, stored as ``image/filename``.

    Returns:
      The populated ``tf.train.Example``.

    Raises:
      ValueError: if a bounding box does not have exactly 4 coordinates.
    """
    # Split the boxes into one list per coordinate (column-wise features).
    ymin, xmin, ymax, xmax = [], [], [], []
    for box in bboxes:
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(box) != 4:
            raise ValueError('Expected 4 box coordinates, got %r' % (box,))
        y_lo, x_lo, y_hi, x_hi = box
        ymin.append(y_lo)
        xmin.append(x_lo)
        ymax.append(y_hi)
        xmax.append(x_hi)

    image_format = b'JPEG'
    # Colorspace and channel count are fixed constants, not read from `shape`.
    colorspace = 'RGB'
    channels = 3
    example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': int64_feature(shape[0]),
            'image/width': int64_feature(shape[1]),
            'image/colorspace': _bytes_feature(colorspace),
            'image/channels': int64_feature(channels),
            'image/label': int64_feature(label_id),
            'image/object/bbox/xmin': float_feature(xmin),
            'image/object/bbox/xmax': float_feature(xmax),
            'image/object/bbox/ymin': float_feature(ymin),
            'image/object/bbox/ymax': float_feature(ymax),
            'image/object/class/label': int64_feature(labels),
            'image/object/bbox/label_text': bytes_feature(labels_text),
            'image/object/view': bytes_feature(poses),
            'image/format': bytes_feature(image_format),
            'image/filename': _bytes_feature(name),
            'image/encoded': bytes_feature(image_data)}))
    return example

"""
Loads data from image and annotations files and add them to a TFRecord.

Args:
  dataset_dir: Dataset directory;
  name: Image name to add to the TFRecord;
  tfrecord_writer: The TFRecord writer to use for writing.
"""
def _add_to_tfrecord(dataset_dir, filename, label_map_path, tfrecord_writer):
    """Convert one VOC sample and append it to an open TFRecord writer.

    Args:
      dataset_dir: the ``VOC<year>`` folder holding images and annotations.
      filename: image id (no extension) to convert.
      label_map_path: accepted for interface compatibility; not used here.
      tfrecord_writer: writer whose ``write`` receives the serialized
        ``tf.train.Example``.
    """
    (image_data, shape, bboxes, labels, labels_text,
     label_id, poses, name) = _process_image(dataset_dir, filename)
    example = _convert_to_example(
        image_data, labels, labels_text, bboxes, shape, label_id, poses, name)
    serialized = example.SerializeToString()
    tfrecord_writer.write(serialized)
def _write_label_file(labels_to_class_names, dataset_dir,
                      filename='labels.txt'):
    """Write one ``<id>:<class name>`` line per entry to ``dataset_dir/filename``.

    Local stand-in for ``dataset_utils.write_label_file``, whose import is
    commented out at the top of this script. Ids are written in sorted order.
    """
    labels_filename = os.path.join(dataset_dir, filename)
    with tf.gfile.Open(labels_filename, 'w') as label_file:
        for label in sorted(labels_to_class_names):
            label_file.write('%d:%s\n' % (label, labels_to_class_names[label]))


def run(voc_root, year, split, output_dir, out_name, label_map_path, shuffling=True):
    """Convert one VOC ImageSets/Main split into sharded TFRecord files.

    Args:
      voc_root: the ``VOCdevkit`` directory.
      year: suffix of the ``VOC<year>`` sub-folder to read from.
      split: name of the list file in ``ImageSets/Main`` without ``.txt``,
        e.g. ``'train'`` or ``'trainval'`` (or any custom split you made).
      output_dir: destination directory; created if it does not exist.
      out_name: tag used in shard names ``VOC0712_<out_name><NN>.tfrecord``.
      label_map_path: forwarded to ``_add_to_tfrecord`` (unused there).
      shuffling: if true, shuffle the sample list with ``RANDOM_SEED``.

    Each shard holds at most ``SAMPLES_PER_FILES`` examples. Afterwards a
    ``labels.txt`` mapping every ``VOC_LABELS`` id to its class name is
    written into ``output_dir``.
    """
    if not tf.gfile.Exists(output_dir):
        tf.gfile.MakeDirs(output_dir)

    dataset_dir = os.path.join(voc_root, 'VOC%s' % year)
    # e.g. VOCdevkit/VOC2012/ImageSets/Main/train.txt lists the sample ids.
    split_file_path = os.path.join(dataset_dir, 'ImageSets', 'Main', '%s.txt' % split)
    print('>> ', split_file_path)
    with open(split_file_path) as split_file:
        # Skip blank lines (e.g. a trailing newline); they would otherwise be
        # looked up as an image literally named '.jpg'.
        filenames = [line.strip() for line in split_file if line.strip()]

    if shuffling:
        random.seed(RANDOM_SEED)
        random.shuffle(filenames)

    total = len(filenames)
    for fidx, start in enumerate(range(0, total, SAMPLES_PER_FILES)):
        tf_filename = '%s/VOC0712_%s%02d.tfrecord' % (output_dir, out_name, fidx)
        with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
            for i in range(start, min(start + SAMPLES_PER_FILES, total)):
                sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, total))
                sys.stdout.flush()
                _add_to_tfrecord(dataset_dir, filenames[i], label_map_path, tfrecord_writer)

    # Use the same 1-based ids that are stored in the records
    # ('image/label', 'image/object/class/label'), not 0-based positions.
    labels_to_class_names = {
        label_id: class_name
        for class_name, (label_id, _) in VOC_LABELS.items()
    }
    _write_label_file(labels_to_class_names, output_dir)
    print('\n>> Finished converting the Pascal VOC dataset!')
    
# User configuration: replace the placeholder paths before running.
dataset_dir = "......../VOCdevkit/"
output_dir = "........./VOC0712/"
name = "trainval"   # split list under ImageSets/Main to convert
out_name = 'train'  # tag embedded in the output .tfrecord file names
label_map_path = 'labels'


def main(_):
    """Entry point for ``tf.app.run``: convert the configured split."""
    # NOTE(review): year=1207 reads from '<dataset_dir>/VOC1207' -- presumably
    # a merged VOC2007+2012 folder; confirm it exists (standard releases use
    # VOC2007 / VOC2012).
    run(voc_root=dataset_dir, year=1207, split=name, output_dir=output_dir,
        out_name=out_name, label_map_path=label_map_path)

if __name__ == '__main__':
    # Hand control to TensorFlow's app runner, which invokes main().
    tf.app.run(main=main)

你可能感兴趣的:(程序,图像分类)