目标检测——数据集处理

目录

1.数据处理

1.1图片数据读取

1.1.1数据集划分

1.1.2真实框解析读取

1.2数据预处理

1.2.1成对读取图片及标号

 1.2.2 图片增广

1.3创建dataset类


1.数据处理

以百度的安全帽数据集为例。

1.1图片数据读取

1.1.1数据集划分

安全帽数据集共有5000张图片和5000个标注文件xml,每个xml文件对应一张图片,在提取数据集的标号前,首先应该划分数据集train、test、val各3750、625、625张,分别占全部数据集的1/4、1/8、1/8。

# 数据预处理:5000张图片和5000个标注xml文件
# 划分集合:train:3750, test:625, val:625
import os
import shutil
filenames = os.listdir('/home/aistudio/work/annotations')
print(len(filenames))

os.mkdir('/home/aistudio/work/annotations/train/')
os.mkdir('/home/aistudio/work/annotations/test/')
os.mkdir('/home/aistudio/work/annotations/val/')
for id, filename in enumerate(filenames):
    if id < 3750:
        shutil.move(r'/home/aistudio/work/annotations/'+ filename, r'/home/aistudio/work/annotations/train/'+filename)
    elif 3750 <= id < 4375:
        shutil.move(r'/home/aistudio/work/annotations/'+ filename, r'/home/aistudio/work/annotations/test/'+filename)
    else:
        shutil.move(r'/home/aistudio/work/annotations/'+ filename, r'/home/aistudio/work/annotations/val/'+filename)

1.1.2真实框解析读取

xml文件的内容:


    images
    hard_hat_workers0.png
    
        416
        416
        3
    
    0
    
        helmet
        Unspecified
        0
        0
        0
        
            357
            116
            404
            175
        
    

编写get_annotations函数,解析XML文件返回一个由字典组成的列表,每个字典的内容包括:(后面有#的为重点使用的信息)

voc_rec = {

            'im_file': img_file,  #图片文件路径名 

            'h': im_h,                # 图片的高

            'w': im_w,                #图片的宽

            'gt_class': gt_class,  # 真实框的类别 ,['helmet', 'head', 'person'] 共3类

            'gt_bbox': gt_bbox,  # 真实框的位置 

            'gt_poly': [],   

            'difficult': difficult

            }

get_annotations代码函数:3个参数cname2cid是类名序号的映射字典, datadir标注文件的目录, imgdir图片文件的目录

import xml.etree.ElementTree as ET
import numpy as np

def get_annotations(cname2cid, datadir, imgdir):
    # 列出文件名
    filenames = os.listdir(datadir)
    records = []
    for fname in filenames:
        # 解析xml
        fpath = os.path.join(datadir, fname)
        tree = ET.parse(fpath)
        
        # 解析图片
        img_file = os.path.join(imgdir, tree.find('filename').text)
        im_w = float(tree.find('size').find('width').text)
        im_h = float(tree.find('size').find('height').text)
        depth = int(tree.find('size').find('height').text)

        # 解析锚框标签
        objs = tree.findall('object')
        gt_bbox = np.zeros((len(objs), 4), dtype=np.float32)
        gt_class = np.zeros((len(objs), ), dtype=np.int32)
        is_crowd = np.zeros((len(objs), ), dtype=np.int32)
        difficult = np.zeros((len(objs), ), dtype=np.int32)
        
        for i, obj in enumerate(objs):
            # 这里cname2cid是名字--序号映射
            cname = obj.find('name').text
            gt_class[i] = cname2cid[cname]
            _difficult = int(obj.find('difficult').text)
            x1 = float(obj.find('bndbox').find('xmin').text)
            y1 = float(obj.find('bndbox').find('ymin').text)
            x2 = float(obj.find('bndbox').find('xmax').text)
            y2 = float(obj.find('bndbox').find('ymax').text)
            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = min(im_w - 1, x2)
            y2 = min(im_h - 1, y2)
            # 这里使用xywh格式来表示目标物体真实框
            # 公式:center_x = (x1 + x2)/2, center_y = (y1 + y2)/2, w = x2 - x1 +1, h = y2 - y1 + 1
            gt_bbox[i] = [(x1+x2)/2.0 , (y1+y2)/2.0, x2-x1+1., y2-y1+1.]
            is_crowd[i] = 0
            difficult[i] = _difficult

        voc_rec = {
            'im_file': img_file,
            'h': im_h,
            'w': im_w,
            'is_crowd': is_crowd,
            'gt_class': gt_class,
            'gt_bbox': gt_bbox,
            'gt_poly': [],
            'difficult': difficult
            }
        if len(objs) != 0:
            records.append(voc_rec)
    return records 
  

用train测试,得到records字典,查看列表长度,第二个字典的内容。

TRAINDIR = '/home/aistudio/work/annotations/train'
TESTDIR = '/home/aistudio/work/annotations/test'
VALIDDIR = '/home/aistudio/work/annotations/val'
IMGDIR = '/home/aistudio/work/images'
cname2cid = {
    'helmet': 0, 
    'head': 1, 
    'person': 2
    }
records = get_annotations(cname2cid, TRAINDIR, IMGDIR)
print(len(records))
print(records[1])

3750
{'im_file': '/home/aistudio/work/image/hard_hat_workers3295.png',
 'h': 415.0,
 'w': 416.0,
 'is_crowd': array([0, 0, 0, 0], dtype=int32),
 'gt_class': array([0, 0, 0, 0], dtype=int32),
 'gt_bbox': array([[111.5, 127.5,  76. ,  72. ],
        [293.5, 134. ,  90. ,  87. ],
        [111.5,  55.5,  76. ,  72. ],
        [293.5,  48. ,  90. ,  87. ]], dtype=float32),
 'gt_poly': [],
 'difficult': array([0, 0, 0, 0], dtype=int32)}

1.2数据预处理

1.2.1成对读取图片及标号

前面已经将图片的所有描述信息保存在records中了,其中每一个元素都包含了一张图片的描述,下面的程序展示了如何根据records里面的描述读取图片及标注。

就结果而言,我们读一条记录返回以下数据:图片img, 真实框位置信息gt_boxes, 真实框类别gt_labels, 图片大小(h, w)。之所以返回图片大小是因为真实框的位置信息转为了相对位置(既除以自己的宽或者高得到的相对值)

首先,由于每张图片的物体个数不同,这里统一为一张图片最多10个物体,不足的补0,多的舍弃。

# 读取图片及标注
import cv2
def get_bbox(gt_bbox, gt_class):
    # 对于一般的检测任务来说,一张图片上往往会有多个目标物体
    # 设置参数MAX_NUM = 10, 即一张图片最多取10个真实框;如果真实
    # 框的数目少于10个,则将不足部分的gt_bbox, gt_class和gt_score的各项数值全设置为0
    MAX_NUM = 10
    gt_bbox2 = np.zeros((MAX_NUM, 4))
    gt_class2 = np.zeros((MAX_NUM,))
    for i in range(len(gt_bbox)):
        gt_bbox2[i, :] = gt_bbox[i, :]
        gt_class2[i] = gt_class[i]
        if i >= MAX_NUM - 1:
            break
    return gt_bbox2, gt_class2
# 根据records列表的record字典返回一张图片数据及标注
def get_img_data_from_file(record):
    """
    record is a dict as following,
      record = {
            'im_file': img_file,
            'h': im_h,
            'w': im_w,
            'is_crowd': is_crowd,
            'gt_class': gt_class,
            'gt_bbox': gt_bbox,
            'gt_poly': [],
            'difficult': difficult
            }
    """
    im_file = record['im_file']
    h = record['h']
    w = record['w']
    is_crowd = record['is_crowd']
    gt_class = record['gt_class']
    gt_bbox = record['gt_bbox']
    difficult = record['difficult']

    img = cv2.imread(im_file)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # check if h and w in record equals that read from img
    assert img.shape[0] == int(h), \
             "image height of {} inconsistent in record({}) and img file({})".format(
               im_file, h, img.shape[0])

    assert img.shape[1] == int(w), \
             "image width of {} inconsistent in record({}) and img file({})".format(
               im_file, w, img.shape[1])

    gt_boxes, gt_labels = get_bbox(gt_bbox, gt_class)

    # gt_bbox 用相对值
    gt_boxes[:, 0] = gt_boxes[:, 0] / float(w)
    gt_boxes[:, 1] = gt_boxes[:, 1] / float(h)
    gt_boxes[:, 2] = gt_boxes[:, 2] / float(w)
    gt_boxes[:, 3] = gt_boxes[:, 3] / float(h)
  
    return img, gt_boxes, gt_labels, (h, w)

 1.2.2 图片增广

参考:飞桨PaddlePaddle-源于产业实践的开源深度学习平台

实现图片的填充、裁剪、缩放、翻转等。对于图片大小不统一的可在这里用缩放统一起来。

1.3创建dataset类

# 定义数据读取类,继承Paddle.io.Dataset
class TrainDataset(paddle.io.Dataset):
    def  __init__(self, datadir, imgdir, mode='train'):
        self.datadir = datadir
        cname2cid = cname2cid = {
            'helmet': 0, 
            'head': 1, 
            'person': 2
            }
        self.records = get_annotations(cname2cid, datadir, imgdir=imgdir)
        self.img_size = 448   #get_img_size(mode)

    def __getitem__(self, idx):
        record = self.records[idx]
        img, gt_bbox, gt_labels, im_shape = get_img_data(record, size=self.img_size)

        return img, gt_bbox, gt_labels, np.array(im_shape)

    def __len__(self):
        return len(self.records)

使用Dataloder加速读取

# 创建数据读取类
train_dataset = TrainDataset(TRAINDIR, imgdir=IMGDIR, mode='train')

# 使用paddle.io.DataLoader创建数据读取器,并设置batchsize,进程数量num_workers等参数
train_loader = paddle.io.DataLoader(train_dataset, batch_size=2, shuffle=True, drop_last=True)

你可能感兴趣的:(目标检测,目标检测,百度,深度学习)