自制mxnet的yolo v2的rec数据集

mxnet官方yolov2的教程: http://zh.gluon.ai/chapter_computer-vision/yolo.html

其rec制作没有交代。。。。本人解析了一下,其label格式如下:

[4,5,512,512,0,0.3,0.4,0.5,0.2]——>[头长度(表示4,...,512共4个数字都是头),目标描述长(id号+框的4个参数,共5个数字来描述目标信息),(512,512)训练时没用到,填什么都行,0是目标id,剩下4个数是框的信息,详见voc训练集的xml文件],xml文件的产生,请自行百度labelImg的用法。。。。

编了一个rec产生函数,大家可以参考

def gen_det_rec(classes,img_dir,ratio=1):
    import os, sys, cv2, random
    import xml.etree.ElementTree as ET
    file_names = os.listdir(img_dir)
    file_names.sort()
    file_num = int(len(file_names) / 2)
    if ratio<1 and ratio>0:
        idx_random = list(range(file_num))
        random.shuffle(idx_random)
        idx_train=idx_random[:int(file_num*ratio)+1]
        idx_val=idx_random[int(file_num*ratio)+1:]
    else:
        print('wrong ratio value!')


    record_train = mx.recordio.MXIndexedRecordIO('train.idx', 'train.rec', 'w')
    if idx_val:
        record_val = mx.recordio.MXIndexedRecordIO('val.idx', 'val.rec', 'w')


    for idx in range(file_num):
        img = cv2.imread(img_dir + file_names[2 * idx])
        label_file = img_dir + file_names[2 * idx + 1]
        tree = ET.parse(label_file)
        root = tree.getroot()
        size = root.find('size')
        width = float(size.find('width').text)
        height = float(size.find('height').text)
        label = []


        for obj in root.iter('object'):
            difficult = int(obj.find('difficult').text)
            # if not self.config['use_difficult'] and difficult == 1:
            #     continue
            cls_name = obj.find('name').text
            if cls_name not in classes:
                continue
            cls_id = classes.index(cls_name)
            xml_box = obj.find('bndbox')
            xmin = float(xml_box.find('xmin').text) / width
            ymin = float(xml_box.find('ymin').text) / height
            xmax = float(xml_box.find('xmax').text) / width
            ymax = float(xml_box.find('ymax').text) / height
            label.append([4, 5, width, height, cls_id, xmin, ymin, xmax, ymax])


        header = mx.recordio.IRHeader(0, label, cls_id, 0)
        packed_s = mx.recordio.pack_img(header, img)
        if idx in idx_train:
            record_train.write_idx(idx, packed_s)
        else:
            record_val.write_idx(idx, packed_s)
    record_train.close()
    if idx_val:
        record_val.close()

应用方法:

calsses=[‘xxxx’,‘dumy’]
img=’/home/yyy/imgs/'

ratio=0.9

gen_det_rec(classes,img_dir,ratio)



你可能感兴趣的:(自制mxnet的yolo v2的rec数据集)