I have the TT100K dataset on hand, and its image annotations live in a JSON file (the dataset authors labeled the images with an annotation tool they wrote in Qt). I want to train with the TensorFlow Object Detection API, but the demos I did before all used XML annotations, so how do we turn TT100K into the TFRecord files we need?
Start from the previous post (converting the VOC dataset into TFRecord files): https://blog.csdn.net/m0_37970224/article/details/89305787
and make a few changes~
The main difference is that instead of reading the contents of XML files, we now read the same information out of the JSON data~
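Before changing anything, it helps to see what the JSON actually contains. The little snippet below just loads annotations.json and prints one entry; the path is my local example, and the only fields the converter relies on are annos['types'] plus, for each image id, the objects list with its 'category' and its pixel-coordinate 'bbox' (xmin/ymin/xmax/ymax).
# coding=utf-8
# Quick look at annotations.json so we know which fields the converter has to read.
import json

with open(r'E:\data\TT100K\data\annotations.json') as f:   # adjust to your own path
    annos = json.load(f)

print(len(annos['types']))                  # the list of category names
first_id = next(iter(annos['imgs']))        # image ids are the keys of annos['imgs']
for obj in annos['imgs'][first_id]['objects']:
    # every object carries a category name and a pixel-coordinate bbox
    print(obj['category'], obj['bbox']['xmin'], obj['bbox']['ymin'],
          obj['bbox']['xmax'], obj['bbox']['ymax'])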
Here is the code: tt100k_to_tfrecord.py
# coding=utf-8
import os
import sys
import random
import tensorflow as tf
import json
from PIL import Image
# Image sub-directory under the dataset root; switch to './train/' when converting the training split.
# DIRECTORY_IMAGES = './train/'
DIRECTORY_IMAGES = './test/'
RANDOM_SEED = 4242
SAMPLES_PER_FILES = 1600
def int64_feature(values):
    """Returns a TF-Feature of int64s.
    Args:
        values: A scalar or list of values.
    Returns:
        a TF-Feature.
    """
    if not isinstance(values, (tuple, list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
def float_feature(value):
    """Wrapper for inserting float features into Example proto."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def bytes_feature(value):
    """Wrapper for inserting bytes features into Example proto."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
def _process_image(directory, name):
    # Read the image file.
    filename = os.path.join(directory, DIRECTORY_IMAGES, name + '.jpg')
    image_data = tf.gfile.FastGFile(filename, 'rb').read()
    # Read the json annotation file (one big file for the whole dataset,
    # so re-parsing it for every image works but is slow).
    filedir = directory + "/annotations.json"
    annos = json.loads(open(filedir).read())
    # Image shape.
    with Image.open(filename) as img:
        shape = [img.height, img.width, 3]
    # Collect each object's bbox and label.
    bboxes = []
    labels = []
    labels_text = []
    for obj in annos['imgs'][name]['objects']:
        label = obj['category']
        labels.append(annos['types'].index(label) + 1)   # label ids start at 1
        labels_text.append(label.encode('utf8'))
        bbox = obj['bbox']
        # Store bboxes as (ymin, xmin, ymax, xmax), normalized to [0, 1].
        bboxes.append((float(bbox['ymin']) / shape[0],
                       float(bbox['xmin']) / shape[1],
                       float(bbox['ymax']) / shape[0],
                       float(bbox['xmax']) / shape[1]
                       ))
    return image_data, shape, bboxes, labels, labels_text
def _convert_to_example(image_data, labels, labels_text, bboxes, shape):
    xmin = []
    ymin = []
    xmax = []
    ymax = []
    for b in bboxes:
        assert len(b) == 4
        # pylint: disable=expression-not-assigned
        [l.append(point) for l, point in zip([ymin, xmin, ymax, xmax], b)]
        # pylint: enable=expression-not-assigned
    image_format = b'JPEG'
    example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': int64_feature(shape[0]),
            'image/width': int64_feature(shape[1]),
            'image/channels': int64_feature(shape[2]),
            'image/shape': int64_feature(shape),
            'image/object/bbox/xmin': float_feature(xmin),
            'image/object/bbox/xmax': float_feature(xmax),
            'image/object/bbox/ymin': float_feature(ymin),
            'image/object/bbox/ymax': float_feature(ymax),
            'image/object/class/label': int64_feature(labels),
            'image/object/class/text': bytes_feature(labels_text),
            'image/format': bytes_feature(image_format),
            'image/encoded': bytes_feature(image_data)}))
    return example
def _add_to_tfrecord(dataset_dir, name, tfrecord_writer):
    image_data, shape, bboxes, labels, labels_text = \
        _process_image(dataset_dir, name)
    print(shape, bboxes, labels, labels_text)
    example = _convert_to_example(image_data,
                                  labels,
                                  labels_text,
                                  bboxes,
                                  shape)
    tfrecord_writer.write(example.SerializeToString())
def run(tt100k_root, split, output_dir, shuffling=False):
    # Create output_dir if it does not exist.
    if not tf.gfile.Exists(output_dir):
        tf.gfile.MakeDirs(output_dir)
    # TT100K/data/train/ids.txt stores the names of all training samples
    # across the 221 categories, 6105 in total (test/ids.txt works the same way).
    split_file_path = os.path.join(tt100k_root, split, 'ids.txt')
    print('>> ', split_file_path)
    with open(split_file_path) as f:
        filenames = f.readlines()
    # Shuffle the order when shuffling == True.
    if shuffling:
        random.seed(RANDOM_SEED)
        random.shuffle(filenames)
    # Process dataset files.
    i = 0
    fidx = 0
    while i < len(filenames):
        # Open a new TFRecord shard every SAMPLES_PER_FILES images.
        tf_filename = '%s/%s_%03d.tfrecord' % (output_dir, split, fidx)
        with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
            j = 0
            while i < len(filenames) and j < SAMPLES_PER_FILES:
                sys.stdout.write('\r>> Converting image %d/%d' % (i + 1, len(filenames)))
                sys.stdout.flush()
                filename = filenames[i].strip()
                _add_to_tfrecord(tt100k_root, filename, tfrecord_writer)
                i += 1
                j += 1
        fidx += 1
    print('\n>> Finished converting the TT100K %s dataset!' % split)
if __name__ == '__main__':
    # Use a raw string for the Windows path so the backslashes are kept as-is.
    run(r'E:\data\TT100K\data', 'test', './data/tt100k/test')
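Once the script has run, it never hurts to read one example back from a shard and check that the fields look right. Below is a minimal sanity-check sketch, assuming the output path from the run() call above and using the same TF 1.x API as the converter:
# coding=utf-8
# Sanity check: read one example back from a generated shard and print
# the fields we just wrote (TF 1.x API, matching the converter above).
import tensorflow as tf

record_path = './data/tt100k/test/test_000.tfrecord'   # first shard from the run() call above

for record in tf.python_io.tf_record_iterator(record_path):
    example = tf.train.Example()
    example.ParseFromString(record)
    feat = example.features.feature
    print('height:', feat['image/height'].int64_list.value)
    print('width :', feat['image/width'].int64_list.value)
    print('labels:', feat['image/object/class/label'].int64_list.value)
    print('texts :', feat['image/object/class/text'].bytes_list.value)
    print('xmin  :', feat['image/object/bbox/xmin'].float_list.value)
    break   # one example is enough for a quick look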
Adapt this to your own data in the same spirit and you will be able to convert it into TFRecord files too~
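One last note: since the goal here is training with the TensorFlow Object Detection API, you will also want a label_map.pbtxt whose ids match the 'image/object/class/label' values written above (i.e. annos['types'].index(category) + 1). Here is a minimal sketch for generating one from the same annotations.json; the file names and paths are my own choices, so adjust them to your setup:
# coding=utf-8
# Minimal sketch: build a label_map.pbtxt whose ids match the
# annos['types'].index(category) + 1 labels written by the converter above.
import json

with open(r'E:\data\TT100K\data\annotations.json') as f:     # adjust to your own path
    annos = json.load(f)

with open('tt100k_label_map.pbtxt', 'w') as f:               # output file name is my own choice
    for idx, category in enumerate(annos['types']):
        f.write("item {\n")
        f.write("  id: %d\n" % (idx + 1))                    # ids start at 1; 0 is reserved for background
        f.write("  name: '%s'\n" % category)
        f.write("}\n\n")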