Tensorflow2.0 YOLO篇之提取xml文件信息

Tensorflow2.0 YOLO篇之提取xml文件信息


  • YOLO篇之算法原理介绍
  • YOLO篇之提取xml文件信息
  • YOLO篇之图像信息预处理
  • YOLO篇之YOLO1论文
  • YOLO篇之模型搭建与训练

数据集介绍

数据集下载地址:

链接:https://pan.baidu.com/s/1ZP9H2ym3Vp4Sda1mNiv9Pw 
提取码:5okb 
复制这段内容后打开百度网盘手机App,操作更方便哦

这次选择的数据集是甜菜(sugarbeet)和杂草(weed)的数据集
Tensorflow2.0 YOLO篇之提取xml文件信息_第1张图片
在数据集的xml文件中包含了图片中每个物体的边界框位置(xmin,ymin,xmax,ymax)和label
其中的一个xml文件


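<annotation>
	<folder>train</folder>
	<filename>X2-10-1.png</filename>
	<source>
		<database>Unknown</database>
	</source>
	<size>
		<width>512</width>
		<height>512</height>
		<depth>3</depth>
	</size>
	<segmented>0</segmented>
	<object>
		<name>weed</name>
		<pose>Unspecified</pose>
		<truncated>0</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>71</xmin>
			<ymin>265</ymin>
			<xmax>115</xmax>
			<ymax>278</ymax>
		</bndbox>
	</object>
	
	......

	<object>
		<name>sugarbeet</name>
		<pose>Unspecified</pose>
		<truncated>0</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>322</xmin>
			<ymin>266</ymin>
			<xmax>363</xmax>
			<ymax>294</ymax>
		</bndbox>
	</object>
</annotation>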

现在我们要做的工作就是将这些数据储存到numpy数组中去,代码中我尽可能地写了注释。书写这段代码的时候选择了vscode作为编辑工具,因为vscode对于jupyter的支持较好,可以在编写的过程中更加方便地查看每一步的运行结果

同时在编写这一步的时候需要注意的一个点就是每个图片中的物体个数可能不一样,这样我们的boxes的个数就有问题。因为每个图片中的框信息都没有超过5个(上图除外那是我自己画的),所以我们每一张图片都设计有五个空,不足的就用0来填充

#%%
import os,glob
# TF_CPP_MIN_LOG_LEVEL has to be in the environment before tensorflow is
# imported; setting it afterwards does not silence the native log output.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import tensorflow as tf
from tensorflow import keras
# set random seed so runs are reproducible
tf.random.set_seed(2233)
np.random.seed(2233)

# %%
print(tf.__version__)
# GPU check lives in tf.test (tf.text is not a TensorFlow module)
print(tf.test.is_gpu_available())



# %%
import xml.etree.ElementTree as ET

def parse_annotation(img_dir,ann_dir,labels):
    """Parse the annotation XML files in ann_dir into image paths and boxes.

    Args:
        img_dir: directory that is joined with each annotation's <filename>
            text to build the image path.
        ann_dir: directory containing the annotation XML files.
        labels: sequence of class names, e.g. ('sugarbeet', 'weed'). A box's
            label id is its name's index in this sequence plus 1, so that 0
            can mark padding rows.

    Returns:
        imgs: list of image paths, one per annotation file, in sorted
            annotation-file-name order.
        boxes: float array of shape [len(imgs), max_boxes, 5] where each row
            is (xmin, ymin, xmax, ymax, label) and max_boxes is the largest
            number of boxes found in any single file. Images with fewer
            boxes are zero-padded.

    Raises:
        AssertionError: if an annotation's width or height is not 512.
        ValueError: if an object name is not present in labels.
    """
    imgs_info = []
    max_boxes = 0
    # os.listdir order is arbitrary; sort for a reproducible sample order and
    # skip stray non-XML files that ET.parse would choke on.
    for ann in sorted(os.listdir(ann_dir)):
        if not ann.lower().endswith('.xml'):
            continue
        tree = ET.parse(os.path.join(ann_dir,ann))

        img_info = {'object': []}
        for elem in tree.iter():
            if 'filename' in elem.tag:
                img_info['filename'] = os.path.join(img_dir,elem.text)
            if 'width' in elem.tag:
                img_info['width'] = int(elem.text)
                assert img_info['width'] == 512
            if 'height' in elem.tag:
                img_info['height'] = int(elem.text)
                # validate the height itself (not the width a second time)
                assert img_info['height'] == 512
            if 'object' in elem.tag or 'part' in elem.tag:
                # xmin-ymin-xmax-ymax-label
                object_info = [0,0,0,0,0]
                for attr in elem:
                    if 'name' in attr.tag:
                        # +1 keeps label 0 free for the zero padding rows
                        object_info[4] = labels.index(attr.text) + 1
                    if 'bndbox' in attr.tag:
                        for pos in attr:
                            if 'xmin' in pos.tag:
                                object_info[0] = int(pos.text)
                            if 'ymin' in pos.tag:
                                object_info[1] = int(pos.text)
                            if 'xmax' in pos.tag:
                                object_info[2] = int(pos.text)
                            if 'ymax' in pos.tag:
                                object_info[3] = int(pos.text)
                img_info['object'].append(object_info)
        imgs_info.append(img_info) # filename, width, height, box list
        max_boxes = max(max_boxes,len(img_info['object']))

    # [b, max_boxes, 5]; rows left at zero are padding
    boxes = np.zeros([len(imgs_info),max_boxes,5])
    imgs = []
    for i,img_info in enumerate(imgs_info):
        # [N,5] where N is this image's box count
        img_boxes = np.array(img_info['object'])
        # An image without objects gives a (0,) array that cannot be assigned
        # into the (0,5) slice, so leave its rows as zeros.
        if img_boxes.size:
            boxes[i,:img_boxes.shape[0]] = img_boxes
        imgs.append(img_info['filename'])
        print(img_info['filename'],boxes[i,:5])
    return imgs,boxes

# %%
# class names; a name's position in this tuple (+1) becomes its label id
obj_names = ('sugarbeet','weed')
# parse the training annotations into image paths and padded box arrays
imgs,boxes = parse_annotation(
    img_dir='data/train/image',
    ann_dir='data/train/annotation',
    labels=obj_names,
)

参考书籍: TensorFlow 深度学习 — 龙龙老师

你可能感兴趣的:(小白的ai学习之路,TF2,算法,python,tensorflow,深度学习)