gluoncv object detection: training on your own dataset

https://gluon-cv.mxnet.io/build/examples_datasets/detection_custom.html

The official tutorial offers two options: a .lst file, or XML files in Pascal VOC format.

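There are labeling tools that produce VOC-format annotations, but if your annotations are stored as JSON or in some other format, you would first have to convert them to VOC.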

So I went with the first option, the .lst format, which is very simple.
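For reference, each line of a detection .lst file is tab-separated: an integer index, the header fields (A = header length, B = values per object, then image width and height), B values per object (class id followed by the corner coordinates), and finally the image path. The helper `parse_lst_line` below is hypothetical and written only to illustrate that layout (GluonCV's LstDetection does its own parsing); it is a minimal sketch that splits one line back into its parts:

```python
def parse_lst_line(line):
    """Split one detection .lst line into (index, extra_header, objects, path).

    Assumed layout: idx, A, B, <A-2 extra header fields>, <B values per object>..., path
    """
    fields = line.rstrip('\n').split('\t')
    idx = int(fields[0])
    header_len = int(float(fields[1]))   # A: header length, counting A and B themselves
    label_width = int(float(fields[2]))  # B: values per object (class id + 4 coordinates)
    # extra header fields after A and B; here image width and height
    extra = [float(x) for x in fields[3:1 + header_len]]
    # everything between the header and the trailing path is object labels
    values = [float(x) for x in fields[1 + header_len:-1]]
    objects = [values[i:i + label_width] for i in range(0, len(values), label_width)]
    return idx, extra, objects, fields[-1]


line = '0\t4\t5\t640\t480\t0.0\t0.1\t0.2\t0.5\t0.5\timages/0001.jpg\n'
print(parse_lst_line(line))
# → (0, [640.0, 480.0], [[0.0, 0.1, 0.2, 0.5, 0.5]], 'images/0001.jpg')
```

An image with several objects simply has several groups of B values in a row before the path.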

Convert your own JSON (or other) annotation files into it accordingly:

import json
import os

import numpy as np


def write_line(img_path, im_shape, boxes, ids, idx):
    """Format one image's annotations as a tab-separated .lst line."""
    h, w, c = im_shape
    # for header, we use minimal length 2, plus width and height
    # with A: 4, B: 5, C: width, D: height
    A = 4
    B = 5
    C = w
    D = h
    # concat id and bboxes: each row is [class_id, xmin, ymin, xmax, ymax]
    labels = np.hstack((ids.reshape(-1, 1), np.asarray(boxes))).astype('float')
    # normalized bboxes (recommended)
    labels[:, (1, 3)] /= float(w)
    labels[:, (2, 4)] /= float(h)
    # flatten
    labels = labels.flatten().tolist()
    str_idx = [str(idx)]
    str_header = [str(x) for x in [A, B, C, D]]
    str_labels = [str(x) for x in labels]
    str_path = [img_path]
    line = '\t'.join(str_idx + str_header + str_labels + str_path) + '\n'
    return line


# Collect every .json annotation file under train_front/<subdir>/
json_url = []
for sub in os.listdir('train_front'):
    for name in os.listdir(os.path.join('train_front', sub)):
        if name.endswith('.json'):
            json_url.append('train_front/' + sub + '/' + name)
print(len(json_url))

# Number of annotated images in this particular dataset: the first half goes
# to train.lst, the second half to val.lst. Adjust for your own data.
TOTAL_IMAGES = 27853
SPLIT = TOTAL_IMAGES // 2

first_boxes = None  # boxes of the first written image, printed as a sanity check
cnt = 0   # images written so far (both files)
cnt1 = 0  # running index within train.lst
cnt2 = 0  # running index within val.lst
with open("train.lst", "w") as fwtrain, open("val.lst", "w") as fwval:
    for json_url_index in json_url:
        with open(json_url_index, 'r') as f:
            # one JSON object per line, one image per object
            for line in f:
                js = json.loads(line)
                if 'person' not in js:
                    continue

                # images live in a folder named after the .json file (minus ".json")
                url = '/mnt/hdfs-data-4/data/jian.yin/' + json_url_index[:-5] + '/' + js['image_key']
                width = js['width']
                height = js['height']

                boxes = []
                ids = []
                for person in js['person']:
                    attrs = person['attrs']
                    # skip ignored, heavily occluded and invisible pedestrians
                    if attrs['ignore'] == 'yes' or attrs['occlusion'] in ('heavily_occluded', 'invisible'):
                        continue
                    boxes.append(person['data'])  # assumed [xmin, ymin, xmax, ymax] in pixels
                    ids.append(0)                 # single class: pedestrian -> id 0

                if not boxes:
                    continue
                if first_boxes is None:
                    first_boxes = boxes
                ids = np.array(ids)

                if cnt < SPLIT:
                    fwtrain.write(write_line(url, (height, width, 3), boxes, ids, cnt1))
                    cnt1 += 1
                else:
                    fwval.write(write_line(url, (height, width, 3), boxes, ids, cnt2))
                    cnt2 += 1
                cnt += 1

print(first_boxes)

That's it: the .lst files are ready.
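Before training it is worth checking the generated files: if any annotation box extends past the image border, its normalized coordinates end up below 0 or above 1. The `check_lst` function below is hypothetical, a minimal sketch assuming the layout written by `write_line` above (it reads A and B from each line rather than hard-coding them):

```python
import os
import tempfile


def check_lst(path):
    """Count image lines in a detection .lst file and validate their boxes.

    Raises ValueError if the label count is not a multiple of B, if a box
    coordinate lies outside [0, 1], or if xmax <= xmin or ymax <= ymin.
    """
    n = 0
    with open(path) as f:
        for lineno, line in enumerate(f, 1):
            if not line.strip():
                continue  # tolerate blank lines
            fields = line.rstrip('\n').split('\t')
            header_len = int(float(fields[1]))   # A
            label_width = int(float(fields[2]))  # B
            values = [float(x) for x in fields[1 + header_len:-1]]
            if label_width < 5 or len(values) % label_width:
                raise ValueError(f'line {lineno}: bad label length')
            for i in range(0, len(values), label_width):
                # values[i] is the class id; the next four are the box corners
                x0, y0, x1, y1 = values[i + 1:i + 5]
                if not all(0.0 <= v <= 1.0 for v in (x0, y0, x1, y1)):
                    raise ValueError(f'line {lineno}: coordinate outside [0, 1]')
                if x1 <= x0 or y1 <= y0:
                    raise ValueError(f'line {lineno}: degenerate box')
            n += 1
    return n


# Demo on a throwaway file; in practice call check_lst('train.lst') and check_lst('val.lst')
with tempfile.TemporaryDirectory() as d:
    demo = os.path.join(d, 'demo.lst')
    with open(demo, 'w') as f:
        f.write('0\t4\t5\t640\t480\t0.0\t0.1\t0.2\t0.5\t0.5\timg.jpg\n')
    print(check_lst(demo))  # → 1
```

Clipping out-of-range boxes to the image in the conversion script is one way to fix whatever this reports.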


Next, add your own dataset to the training script:

https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/faster_rcnn/train_faster_rcnn.py#L73

The data-loading code from the tutorial above cannot be dropped in here as-is.

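Add the dataset the way the tutorial shows (with LstDetection). For validation, a quick shortcut is to reuse the existing VOC metric directly.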

Alternatively, skip validation altogether by commenting out the relevant part at https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/faster_rcnn/train_faster_rcnn.py#L393.

    elif dataset.lower() == 'pedestrian':
        # requires: from gluoncv.data import LstDetection
        train_dataset = LstDetection('train.lst', root=os.path.expanduser('.'))
        val_dataset = LstDetection('val.lst', root=os.path.expanduser('.'))
        # sanity check: number of images, first image shape and its label array
        print(len(train_dataset))
        first_img, first_label = train_dataset[0]
        print(first_img.shape)
        print(first_label)
        classes = ('pedestrian',)
        # reuse the VOC 2007 mAP metric (IoU threshold 0.5) for validation
        val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=classes)
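Here iou_thresh=0.5 means a detection only counts as a true positive when its intersection-over-union (IoU) with a ground-truth box is at least 0.5. A minimal pure-Python sketch of that overlap measure (not GluonCV's own implementation; it uses plain continuous coordinates with no +1 pixel offset):

```python
def iou(a, b):
    """Intersection-over-union of two boxes given as (xmin, ymin, xmax, ymax)."""
    # corners of the intersection rectangle
    ix0, iy0 = max(a[0], b[0]), max(a[1], b[1])
    ix1, iy1 = min(a[2], b[2]), min(a[3], b[3])
    # zero width/height when the boxes do not overlap
    inter = max(0.0, ix1 - ix0) * max(0.0, iy1 - iy0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


gt = (0, 0, 40, 80)
print(iou(gt, (10, 0, 50, 80)))  # → 0.6, counts as a match at iou_thresh=0.5
print(iou(gt, (20, 0, 60, 80)) >= 0.5)  # → False, IoU is 1/3
```

Since the .lst boxes are normalized, the same ratio comes out whether it is computed in pixels or in [0, 1] coordinates.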

Training parameters:

https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/faster_rcnn/train_faster_rcnn.py#L73

Add your own training hyperparameters, or simply reuse the VOC settings:

    if args.dataset == 'voc' or args.dataset == 'pedestrian':
        args.epochs = int(args.epochs) if args.epochs else 20
        args.lr_decay_epoch = args.lr_decay_epoch if args.lr_decay_epoch else '14,20'
        args.lr = float(args.lr) if args.lr else 0.001
        args.lr_warmup = args.lr_warmup if args.lr_warmup else -1
        args.wd = float(args.wd) if args.wd else 5e-4
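These defaults describe a step learning-rate schedule. The sketch below shows how lr and lr_decay_epoch interact, assuming a multiplicative decay factor of 0.1 (an assumption here; check the value your copy of the script uses) and epochs counted from 0 as in `range(epochs)`. Under those assumptions the step at epoch 20 is never reached when epochs=20, so the rate drops only once, at epoch 14. `lr_at_epoch` is a hypothetical helper, not part of the script:

```python
def lr_at_epoch(epoch, base_lr=0.001, lr_decay_epoch='14,20', lr_decay=0.1):
    """Step schedule: multiply base_lr by lr_decay once for every listed epoch already reached."""
    steps = [int(s) for s in lr_decay_epoch.split(',') if s.strip()]
    return base_lr * lr_decay ** sum(1 for s in steps if epoch >= s)


epochs = 20
for epoch in (0, 13, 14, epochs - 1):  # epochs run 0 .. epochs-1
    print(epoch, lr_at_epoch(epoch))
```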

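Finally, add a mapping for your dataset's model name in model_zoo.py: the training script builds the network name from the network and dataset arguments (here faster_rcnn_resnet50_v1b_pedestrian) and looks it up there. If gluoncv was installed with pip install gluoncv, you have to edit the copy under site-packages instead.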

https://github.com/dmlc/gluon-cv/blob/master/gluoncv/model_zoo/model_zoo.py#L32

'faster_rcnn_resnet50_v1b_pedestrian': faster_rcnn_resnet50_v1b_voc,
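The added line is just an alias in a name-to-constructor dictionary. The toy registry below is a sketch of that lookup pattern only; the stand-in constructor and this `get_model` are hypothetical simplifications, not GluonCV's code, and the name-joining line assumes the script concatenates "faster_rcnn", the network and the dataset with underscores:

```python
# Hypothetical stand-in for the real constructor in gluoncv/model_zoo
def faster_rcnn_resnet50_v1b_voc(**kwargs):
    return ('faster_rcnn', 'resnet50_v1b', 'voc', kwargs)


_models = {
    'faster_rcnn_resnet50_v1b_voc': faster_rcnn_resnet50_v1b_voc,
    # the added alias: a new name pointing at the existing VOC constructor
    'faster_rcnn_resnet50_v1b_pedestrian': faster_rcnn_resnet50_v1b_voc,
}


def get_model(name, **kwargs):
    """Look a model name up in the registry and call its constructor."""
    name = name.lower()
    if name not in _models:
        raise ValueError(f'{name} is not in the model registry')
    return _models[name](**kwargs)


net_name = '_'.join(('faster_rcnn', 'resnet50_v1b', 'pedestrian'))
print(net_name)  # → faster_rcnn_resnet50_v1b_pedestrian
print(get_model(net_name, pretrained_base=True))
```

Because the alias reuses the VOC constructor, the resulting network presumably keeps the 20-class VOC output layer; with only class id 0 in the .lst files, the other outputs simply go unused.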


Reposted from: https://www.cnblogs.com/TreeDream/p/10174899.html
