从LabelImg建立TFRecord格式的训练样本

文章目录

  • 图像样本转化为TFRecord
    • 图像数据处理
      • 用OpenCV读取图像数据
      • 用TensorFlow读取图像数据
    • TFRecord文件的数据读写方法
      • TFRecord的写入
      • TFRecord的读出
    • LabelImg标注文件的格式与解析
    • 将LabelImg的标注及对应的图像文件保存为TFRecord
      • 验证

完整内容见 作者的另一篇博客

图像样本转化为TFRecord

图像数据处理

用OpenCV读取图像数据

用OpenCV读取图像数据的格式为numpy.ndarray

import cv2

# 读取分辩率为1920*1080的3通道图像文件
img = cv2.imread("d:/temp/test_folder/image2.jpg", cv2.IMREAD_COLOR)  # cv2.IMREAD_COLOR值为1
print(img.shape)  # 输出: (1080, 1920, 3)
print(type(img))  # 输出: 
print(img.dtype)  # 输出: uint8

# 修改宽高
img = cv2.resize(img, (480, 270))  # 注意dsize参数的元组按目标宽高排列

# 将BGR格式转为TensorFlow常用的RGB格式
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# 显示图像
cv2.imshow("bgr", img)
cv2.imshow("rgb", img_rgb)
cv2.waitKey(0)

上述代码显示图像如下:
从LabelImg建立TFRecord格式的训练样本_第1张图片

从LabelImg建立TFRecord格式的训练样本_第2张图片

用TensorFlow读取图像数据

用TensorFlow读入图像文件的类型为bytes,即string类型

解析图像文件必须用Session运行tensor才能执行,而且必须注意float类型的归一化及其与uint8的转换

import cv2
import tensorflow as tf

# 读取分辩率为1920*1080的3通道图像文件
with tf.gfile.FastGFile("d:/temp/test_folder/image2.jpg", "rb") as fid:
    img = fid.read()
    print(type(img))  # 输出: 

    # 建立解析jpg图像为原始数据的tensor
    img_tensor = tf.image.decode_jpeg(img)
    print(type(img_tensor))  # 输出: 
    print(img_tensor.shape)  # 输出: (?, ?, ?)
    print(img_tensor.dtype)  # 输出: 

    # 指定tensor的shape
    # 无论是否设置shape,decode_jpeg都能正确解析图像
    img_tensor.set_shape([1080, 1920, 3])
    print(img_tensor.shape)  # 输出: (1080, 1920, 3)

    # 修改宽高
    # 特别注意resize_images输出类型为float32,即将uint8转化为float32类型
    # resize_images即可以操作图像batch的4D数据,也可以转化单幅图像的3D数据
    img_resize_float = tf.image.resize_images(img_tensor, size=(270, 480))
    print(type(img_resize_float))  # 输出: 
    print(img_resize_float.shape)  # 输出: (270, 480, 3),如果img_tensor未指定shape,则输出为(270, 480, ?)
    print(img_resize_float.dtype)  # 输出: 

    # 转换类型为uint8
    # 特别注意convert_image_dtype中float32类型的图像都必须归一化至0~1范围
    # img_resize_float/255用于归一化至0~1的范围
    img_resize_uint = tf.image.convert_image_dtype(img_resize_float/255, dtype=tf.uint8, saturate=True)

    # 用Session完成tensor的计算
    with tf.Session() as sess:
        img_rgb_float = sess.run(img_resize_float)
        print(type(img_rgb_float))  # 输出: 
        print(img_rgb_float.shape)  # 输出: (270, 480, 3)
        print(img_rgb_float.dtype)  # 输出: float32

        img_rgb = sess.run(img_resize_uint)
        print(type(img_rgb))  # 输出: 
        print(img_rgb.shape)  # 输出: (270, 480, 3)
        print(img_rgb.dtype)  # 输出: uint8

        # 用OpenCV显示生成的图像
        img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)

        cv2.imshow("float rgb", img_rgb_float/255)  # imshow显示float类型图像,必须归一化至0~1范围
        cv2.imshow("bgr", img_bgr)
        cv2.imshow("rgb", img_rgb)
        cv2.waitKey(0)

上述代码是显示图像如下:
从LabelImg建立TFRecord格式的训练样本_第3张图片
从LabelImg建立TFRecord格式的训练样本_第4张图片
从LabelImg建立TFRecord格式的训练样本_第5张图片

TFRecord文件的数据读写方法

TFRecord的写入

import tensorflow as tf
"""
TFRecord的写入主要在于以下各类
tf.python_io.TFRecordWriter: tf.python_io.TFRecordWriter.write()执行写操作
tf.train.Example: tf.train.Example.SerializeToString()生成写操作对象
tf.train.Features: 包含tf.train.Feature
tf.train.Feature: 建立tf.train.Example使用的feature
tf.train.BytesList: 建立字符串列表的feature
tf.train.FloatList: 建立浮点数列表的feature
tf.train.Int64List: 建立整型数列表的feature
"""

# tf.train.Feature支持的3种list类型
# 字符串、浮点数、int64整型数
data_bytes = b"\x11\x22\x33\x44"  # b'\x11"3D'
data_float = 3.2
data_int64 = 0

# 建立5个TFRecord文件保存数据
for i in range(5):
    filename = "d:/temp/test_folder/rec{0}.tfrecords".format(i)
    with tf.python_io.TFRecordWriter(filename) as writer:  # 用with而不是调用writer.close()
        example = tf.train.Example(features=tf.train.Features(
            feature={
                # feature参数输入为字典,key为字符串,值为tf.train.Feature对象
                # tf.train.Feature三种list分别传入不同类型的list
                # bytes_list=tf.train.BytesList(value=字符串列表)
                # float_list=tf.train.FloatList(value=浮点数列表)
                # int64_list=tf.train.Int64List(value=int64整型数列表)
                "bytes": tf.train.Feature(bytes_list=tf.train.BytesList(value=[data_bytes])),
                "float": tf.train.Feature(float_list=tf.train.FloatList(value=[float(i)])),
                "int64": tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))
            }
        ))

        writer.write(example.SerializeToString())
        

TFRecord的读出

import tensorflow as tf
"""
TFRecord的写入主要在于以下:
tf.train.match_filenames_once: 用于查找TFRecord文件
tf.train.string_input_producer: 用于从文件列表中取出单个文件名
tf.TFRecordReader: tf.TFRecordReader.read()读出TFRecord文件中的数据
tf.parse_single_example: 将TFRecord文件中的数据解析为字典
tf.FixedLenFeature: 用于取出TFRecord中的feature
tf.train.Coordinator: 操作线程的关闭
tf.train.start_queue_runners: 启动tf.train.string_input_producer返回的queue的输出,开启工作
"""

# 获取文件列表
files = tf.train.match_filenames_once("d:/temp/test_folder/rec*.tfrecords")

# 创建输入队列
# 重复输出输入文件列表中的所有文件名,除非用参数num_epochs指定每个文件可轮询的次数
# shuffle参数用于控制随机打乱文件排序
# 返回值说明:
# A queue with the output strings. A QueueRunner for the Queue is added to the current Graph's QUEUE_RUNNER collection.
# 用tf.train.start_queue_runners启动queue的输出
queue = tf.train.string_input_producer(files, shuffle=False)
# 如果只有单个文件,则使用单个文件名构成列表,注意[]
# queue = tf.train.string_input_producer(["d:/temp/test_folder/rec0.tfrecords"], shuffle=False)

# 建立TFRecordReader并解析TFRecord文件
reader = tf.TFRecordReader()
_, serialized_example = reader.read(queue)  # tf.TFRecordReader.read()用于读取queue中的下一个文件
rec_features = tf.parse_single_example(  # 返回字典,字典key值即features参数中的key值
    serialized_example,
    features={
        "bytes": tf.FixedLenFeature(shape=[], dtype=tf.string),
        "float": tf.FixedLenFeature(shape=[], dtype=tf.float32),
        "int64": tf.FixedLenFeature(shape=[], dtype=tf.int64)
    }
)

"""
tensor字典
{'bytes': , 
'float': , 
'int64': }
"""
print(rec_features)

with tf.Session() as sess:
    """
    sess.run(tf.global_variables_initializer())
    print(sess.run(files))
    上述代码运行出错,提示如下:
    Attempting to use uninitialized value matching_filenames
    因为tf.train.match_filenames_once使用的是局部变量,非全局变量
    需要改成下方代码才能正确运行
    """
    sess.run(tf.local_variables_initializer())
    print(sess.run(files))  # 打印文件列表

    # 用子线程启动tf.train.string_input_producer生成的queue
    coord = tf.train.Coordinator()  # 用于控制线程结束
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # 读出TFRecord文件内容
    for i in range(10):
        # 每次run都由string_input_producer更新至下一个TFRecord文件
        print(sess.run(rec_features["int64"]))
        print(sess.run(rec_features))

    coord.request_stop()  # 结束线程
    coord.join(threads)  # 等待线程结束

LabelImg标注文件的格式与解析

主要参考:https://www.jianshu.com/p/86894ccaa407

LabelImg为每个标识的图像文件建立一个对应的xml标注文件,一般情况下默认的标注文件名与图像文件名一致。

典型的标注文件内容如下:

<annotation>
	<folder>test_folderfolder>
	<filename>cap_11.jpgfilename>
	<path>D:\temp\test_folder\cap_11.jpgpath>
	<source>
		<database>Unknowndatabase>
	source>
	<size>
		<width>480width>
		<height>270height>
		<depth>3depth>
	size>
	<segmented>0segmented>
	<object>
		<name>playername>
		<pose>Unspecifiedpose>
		<truncated>0truncated>
		<difficult>0difficult>
		<bndbox>
			<xmin>178xmin>
			<ymin>81ymin>
			<xmax>233xmax>
			<ymax>203ymax>
		bndbox>
	object>
	<object>
		<name>playername>
		<pose>Unspecifiedpose>
		<truncated>0truncated>
		<difficult>0difficult>
		<bndbox>
			<xmin>234xmin>
			<ymin>29ymin>
			<xmax>306xmax>
			<ymax>245ymax>
		bndbox>
	object>
annotation>

用以下代码读取LabelImg的标注文件信息:

import glob  # 用于遍历文件夹内的xml文件
import xml.etree.ElementTree as ET  # 用于解析xml文件


# 遍历文件夹内的全部xml文件,1个xml文件描述1个图像文件的标注信息
for f in glob.glob("d:/temp/test_folder/*.xml"):
    # 解析xml文件
    try:
        tree = ET.parse(f)
    except FileNotFoundError:
        print("无法找到xml文件: "+f)
    except ET.ParseError:
        print("无法解析xml文件: "+f)
    else:  # ET.parse()运行正确
        # 取得xml根节点
        root = tree.getroot()

        # 取得图像路径和文件名
        print(root.find("filename").text)
        print(root.find("path").text)

        # 取得图像宽高
        print(int(root.find("size")[0].text))  # width节点的序号为[0]
        print(int(root.find("size")[1].text))  # height节点的序号为[1]

        # 取得bbox
        for obj in root.findall("object"):  # 查找根节点下全部名为object的节点
            print(int(obj[4][0].text))  # bndbox节点的序号为[4]
            print(int(obj[4][1].text))
            print(int(obj[4][2].text))
            print(int(obj[4][3].text))

将LabelImg的标注及对应的图像文件保存为TFRecord

import tensorflow as tf  # 导入TensorFlow
import cv2  # 导入OpenCV
import os  # 用于文件操作
import glob  # 用于遍历文件夹内的xml文件
import xml.etree.ElementTree as ET  # 用于解析xml文件


# 将LabelImg标注的图像文件和标注信息保存为TFRecord
class LabelImg2TFRecord:

    @classmethod
    def gen(cls, path):
        """
        :param path: LabelImg标识文件的路径,及生成的TFRecord文件路径
        """
        # 遍历文件夹内的全部xml文件,1个xml文件描述1个图像文件的标注信息
        for f in glob.glob(path + "/*.xml"):
            # 解析xml文件
            try:
                tree = ET.parse(f)
            except FileNotFoundError:
                print("无法找到xml文件: "+f)
                return False
            except ET.ParseError:
                print("无法解析xml文件: "+f)
                return False
            else:  # ET.parse()正确运行

                # 取得xml根节点
                root = tree.getroot()

                # 取得图像路径和文件名
                img_name = root.find("filename").text
                img_path = root.find("path").text

                # 取得图像宽高
                img_width = int(root.find("size")[0].text)
                img_height = int(root.find("size")[1].text)

                # 取得所有标注object的信息
                label = []  # 类别名称
                xmin = []
                xmax = []
                ymin = []
                ymax = []

                # 查找根节点下全部名为object的节点
                for m in root.findall("object"):
                    xmin.append(int(m[4][0].text))
                    xmax.append(int(m[4][2].text))
                    ymin.append(int(m[4][1].text))
                    ymax.append(int(m[4][3].text))
                    # 用encode将str类型转为bytes类型,相应的用decode由bytes转回str类型
                    label.append(m[0].text.encode("utf-8"))

                # 至少有1个标注目标
                if len(label) > 0:
                    # 用OpenCV读出图像原始数据,未压缩数据
                    data = cv2.imread(img_path, cv2.IMREAD_COLOR)

                    # 将OpenCV的BGR格式转为RGB格式
                    data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)

                    # 建立Example
                    example = tf.train.Example(features=tf.train.Features(feature={
                        # 用encode将str类型转为bytes类型
                        # 以下各feature的shape固定,读出时必须使用tf.FixedLenFeature
                        "filename": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_name.encode("utf-8")])),
                        "width": tf.train.Feature(int64_list=tf.train.Int64List(value=[img_width])),
                        "height": tf.train.Feature(int64_list=tf.train.Int64List(value=[img_height])),
                        "data": tf.train.Feature(bytes_list=tf.train.BytesList(value=[data.tostring()])),  # 图像数据ndarray转化成bytes类型
                        # 以下各feature的shape不固定,,读出时必须使用tf.VarLenFeature
                        "object/label": tf.train.Feature(bytes_list=tf.train.BytesList(value=label)),
                        "object/bbox/xmin": tf.train.Feature(int64_list=tf.train.Int64List(value=xmin)),
                        "object/bbox/xmax": tf.train.Feature(int64_list=tf.train.Int64List(value=xmax)),
                        "object/bbox/ymin": tf.train.Feature(int64_list=tf.train.Int64List(value=ymin)),
                        "object/bbox/ymax": tf.train.Feature(int64_list=tf.train.Int64List(value=ymax))
                    }))

                    # 建立TFRecord的写对象
                    # img_name.split('.')[0]用于去掉扩展名,只保留文件名
                    with tf.python_io.TFRecordWriter(os.path.join(path, img_name.split('.')[0]+".tfrecords")) as writer:
                        # 数据写入TFRecord文件
                        writer.write(example.SerializeToString())

                        # 结束
                        print("生成TFRecord文件: " + os.path.join(path, img_name.split('.')[0]+".tfrecords"))
                else:
                    print("xml文件{0}无标注目标".format(f))
                    return False

        print("完成全部xml标注文件的保存")
        return True


if __name__ == "__main__":
    LabelImg2TFRecord.gen("d:/temp/test_folder")

验证

import tensorflow as tf
import cv2
import numpy as np


# 获取文件列表
files = tf.train.match_filenames_once("d:/temp/test_folder/cap*.tfrecords")

# 创建输入队列
# 重复输出输入文件列表中的所有文件名,除非用参数num_epochs指定每个文件可轮询的次数
# shuffle参数用于控制随机打乱文件排序
# 返回值说明:
# A queue with the output strings. A QueueRunner for the Queue is added to the current Graph's QUEUE_RUNNER collection.
# 用tf.train.start_queue_runners启动queue的输出
queue = tf.train.string_input_producer(files, shuffle=True)

# 建立TFRecordReader并解析TFRecord文件
reader = tf.TFRecordReader()
_, serialized_example = reader.read(queue)  # tf.TFRecordReader.read()用于读取queue中的下一个文件
rec_features = tf.parse_single_example(  # 返回字典,字典key值即features参数中的key值
    serialized_example,
    features={
        # 写入时shape固定的数值用FixedLenFeature
        "filename": tf.FixedLenFeature(shape=[], dtype=tf.string),  # 由于只有1个值也可以用shape[1],返回list
        "width": tf.FixedLenFeature(shape=[], dtype=tf.int64),
        "height": tf.FixedLenFeature(shape=[], dtype=tf.int64),
        "data": tf.FixedLenFeature(shape=[], dtype=tf.string),
        # 写入时shape不固定的数值,读出时用VarLenFeature,读出为SparseTensorValue类对象
        "object/label": tf.VarLenFeature(dtype=tf.string),
        "object/bbox/xmin": tf.VarLenFeature(dtype=tf.int64),
        "object/bbox/xmax": tf.VarLenFeature(dtype=tf.int64),
        "object/bbox/ymin": tf.VarLenFeature(dtype=tf.int64),
        "object/bbox/ymax": tf.VarLenFeature(dtype=tf.int64),
    }
)

# # 将tf.string转化成tf.uint8的tensor
# img_tensor = tf.decode_raw(rec_features["data"], tf.uint8)
# print(img_tensor.shape)  # 输出: dododo

with tf.Session() as sess:
    """
    sess.run(tf.global_variables_initializer())
    print(sess.run(files))
    上述代码运行出错,提示如下:
    Attempting to use uninitialized value matching_filenames
    因为tf.train.match_filenames_once使用的是局部变量,非全局变量
    需要改成下方代码才能正确运行
    """
    sess.run(tf.local_variables_initializer())
    print(sess.run(files))  # 打印文件列表

    # 用子线程启动tf.train.string_input_producer生成的queue
    coord = tf.train.Coordinator()  # 用于控制线程结束
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # 读出TFRecord文件内容
    for i in range(20):
        # 每次run都由string_input_producer更新至下一个TFRecord文件
        rec = sess.run(rec_features)
        print(rec["filename"].decode("utf-8"))  # 由bytes类型转为str类型
        print("目标数目: " + str(rec["object/label"].values.size))
        print(rec["object/label"].values)
        print(rec["object/bbox/xmin"].values)
        print(rec["object/bbox/xmax"].values)
        print(rec["object/bbox/ymin"].values)
        print(rec["object/bbox/ymax"].values)
        
        # 将图像数据转化为numpy.ndarray
        img = np.fromstring(rec["data"], np.uint8)
        print(type(rec["data"]))  # 输出: 
        print(type(img))  # 输出: 

        # 根据feature设置图像shape
        img = np.reshape(img, (rec["height"], rec["width"], 3))
        print(img.shape)  # 输出: (rec["height"], rec["width"], 3)

        # 将图像由RGB转为RGB用于imshow
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

        # 绘制标注框
        for j in range(rec["object/label"].values.size):
            img = cv2.putText(img,
                              rec["object/label"].values[j].decode("utf-8"),
                              (rec["object/bbox/xmin"].values[j], rec["object/bbox/ymin"].values[j]-2),
                              cv2.FONT_HERSHEY_PLAIN,
                              1,
                              (0, 255, 0)
                              )
            img = cv2.rectangle(img,
                                (rec["object/bbox/xmin"].values[j], rec["object/bbox/ymin"].values[j]),
                                (rec["object/bbox/xmax"].values[j], rec["object/bbox/ymax"].values[j]),
                                (0, 0, 255))

        # 显示图像
        cv2.imshow(rec["filename"].decode("utf-8"), img)
        cv2.waitKey()

    coord.request_stop()  # 结束线程
    coord.join(threads)  # 等待线程结束

显示图像与LabelImg比较
从LabelImg建立TFRecord格式的训练样本_第6张图片
从LabelImg建立TFRecord格式的训练样本_第7张图片

你可能感兴趣的:(tensorflow)