使用mmrotate训练 Dronevehicle 数据集

使用mmrotate训练 Dronevehicle 数据集

目前看到有很多教程讲述mmrotate训练自制数据集或者dota数据集,这里讲一下在visitdrone数据集上的训练过程。

环境配置

和mmdetecion配置几乎一样,可以参考我的上一个博客。

数据集下载

DroneVehicle数据集地址:https://github.com/VisDrone/DroneVehicle
数据集格式:图片:jpg
标签:xml
类别:5类 ‘car’,‘truck’,‘bus’,‘van’,‘feright_car’
(官方首页的标签字母打错了,以标签文件中的类别为准)

Dronevehicle 数据集DOTA化

因为Dronevehicle 数据集是xml格式的,但是mmrotate不支持直接使用该格式进行训练,所以需要使用脚本进行转换,生成dota类型的txt标签文件

Dronevehicle 数据集格式:
使用mmrotate训练 Dronevehicle 数据集_第1张图片
可以看出该标签是以8点坐标来定义的旋转框。

下面来看一下标准的dota数据集的格式:

DOTA数据标签介绍
标注方式:oriented bounding box  定向边界框
x1, y1, x2, y2, x3, y3, x4, y4, category, difficult
x1, y1, x2, y2, x3, y3, x4, y4, category, difficult
...
x1, y1, x2, y2, x3, y3, x4, y4:四边形的四个顶点的坐标 顶点按顺时针顺序排列,第一个起点为左上第一个点
category:实例类别
difficult:表示该实例是否难以检测(1表示困难,0表示不困难)

使用python脚本将xml 文件中的关键数据写入到txt文件进行输出:

import os
import xml.etree.ElementTree as ET
import math
import cv2 as cv


def voc_to_dota(xml_dir, xml_name, img_dir, savedImg_dir):
    txt_name = xml_name[:-4] + '.txt'  # txt文件名字:去掉xml 加上.txt
    txt_path = xml_dir + '/txt_label'  # txt文件目录:在xml目录下创建的txtl_label文件夹
    if not os.path.exists(txt_path):
        os.makedirs(txt_path)
    txt_file = os.path.join(txt_path, txt_name)  # txt完整的含名文件路径

    img_name = xml_name[:-4] + '.jpg'  # 图像名字
    img_path = os.path.join(img_dir, img_name)  # 图像完整路径
    img = cv.imread(img_path)  # 读取图像

    xml_file = os.path.join(xml_dir, xml_name)
    tree = ET.parse(os.path.join(xml_file))  # 解析xml文件 然后转换为DOTA格式文件
    root = tree.getroot()
    with open(txt_file, "w+", encoding='UTF-8') as out_file:
        # out_file.write('imagesource:null' + '\n' + 'gsd:null' + '\n')
        for obj in root.findall('object'):
            name = obj.find('name').text
            # if name == 'car':
            #     name = name
            # else:
            #     name = 'car'
            if name == 'feright car':
                name = 'feright_car'
            else:
                name = name
            obj_difficult = obj.find('difficult')
            if obj_difficult:
                difficult = obj_difficult.text
            else:
                difficult ='0'
            # print(name, difficult)

            if obj.find('bndbox'):
                obj_bnd = obj.find('bndbox')
                obj_xmin = obj_bnd.find('xmin').text
                obj_ymin = obj_bnd.find('ymin').text
                obj_xmax = obj_bnd.find('xmax').text
                obj_ymax = obj_bnd.find('ymax').text
                # w = obj_xmax-obj_xmin
                # h = obj_ymax-obj_ymin
                x1 = obj_xmin
                y1 = obj_ymin
                x2 = obj_xmax
                y2 = obj_ymin
                x3 = obj_xmax
                y3 = obj_ymax
                x4 = obj_xmin
                y4 = obj_ymax
                
            elif obj.find('polygon'):
                obj_bnd = obj.find('polygon')
                x1 = obj_bnd.find('x1').text
                x2 = obj_bnd.find('x2').text
                x3 = obj_bnd.find('x3').text
                x4 = obj_bnd.find('x4').text
                y1 = obj_bnd.find('y1').text
                y2 = obj_bnd.find('y2').text
                y3 = obj_bnd.find('y3').text
                y4 = obj_bnd.find('y4').text
            # robndbox = obj.find('robndbox')
            # cx = float(robndbox.find('cx').text)
            # cy = float(robndbox.find('cy').text)
            # w = float(robndbox.find('w').text)
            # h = float(robndbox.find('h').text)
            # angle = float(robndbox.find('angle').text)
            # print(cx, cy, w, h, angle)
            # 找最左上角的点
            # 在原图上画矩形 看是否转换正确
            # cv.line(img, (int(list_xy[0]), int(list_xy[1])), (int(list_xy[2]), int(list_xy[3])), color=(255, 0, 0),
            #         thickness=3)
            # cv.line(img, (int(list_xy[2]), int(list_xy[3])), (int(list_xy[4]), int(list_xy[5])), color=(0, 255, 0),
            #         thickness=3)
            # cv.line(img, (int(list_xy[4]), int(list_xy[5])), (int(list_xy[6]), int(list_xy[7])), color=(0, 0, 255),
            #         thickness=2)
            # cv.line(img, (int(list_xy[6]), int(list_xy[7])), (int(list_xy[0]), int(list_xy[1])), color=(255, 255, 0),
            #         thickness=2)
            data = str(x1) + " " + str(y1) + " " + str(x2) + " " + str(y2) + " " + \
                   str(x3) + " " + str(y3) + " " + str(x4) + " " + str(y4) + " "
            data = data + name + " " + difficult + "\n"
            out_file.write(data)
        if not os.path.exists(savedImg_dir):
            os.makedirs(savedImg_dir)
        out_img = os.path.join(savedImg_dir, xml_name[:-4] + '.jpg')
        cv.imwrite(out_img, img)


def find_topLeftPopint(dict):
    dict_keys = sorted(dict.keys())  # y值
    temp = [dict[dict_keys[0]], dict[dict_keys[1]]]
    minx = min(temp)
    if minx == temp[0]:
        miny = dict_keys[0]
    else:
        miny = dict_keys[1]
    return [minx, miny]


# # 转换成四点坐标
# def rotatePoint(xc, yc, xp, yp, theta):
#     xoff = xp - xc
#     yoff = yp - yc
#     cosTheta = math.cos(theta)
#     sinTheta = math.sin(theta)
#     pResx = cosTheta * xoff + sinTheta * yoff
#     pResy = - sinTheta * xoff + cosTheta * yoff
#     # pRes = (xc + pResx, yc + pResy)
#     # 保留一位小数点
#     return float(format(xc + pResx, '.1f')), float(format(yc + pResy, '.1f'))
#     # return xc + pResx, yc + pResy


import argparse


def parse_args():
    parser = argparse.ArgumentParser(description='数据格式转换')
    parser.add_argument('--xml-dir', default=r'D:\data\dronevehicle\test\testlabel', help='original xml file dictionary')
    parser.add_argument('--img-dir', default=r'D:\data\dronevehicle\test\testimg', help='original image dictionary')
    parser.add_argument('--outputImg-dir', default=r'D:\data\dronevehicle\test\out',
                        help='saved image dictionary after dealing ')

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    xml_path = args.xml_dir
    xmlFile_list = os.listdir(xml_path)
    print(xmlFile_list)
    for i in range(0, len(xmlFile_list)):
        if ('.xml' in xmlFile_list[i]) or ('.XML' in xmlFile_list[i]):
            voc_to_dota(xml_path, xmlFile_list[i], args.img_dir, args.outputImg_dir)
            print('----------------------------------------{}{}----------------------------------------'
                  .format(xmlFile_list[i], ' has Done!'))
        else:
            print(xmlFile_list[i] + ' is not xml file')


**注意:**这里的数据集转换有几个坑:
1、Dronevehicle 数据集中有旋转框和水平框,转换时要进行分类获取坐标
2、其中有个类别为feright car,由于mmrotate读取标签的机制,这个字符串在txt中不会直接被读取,因为中间有空格,会认为是两个标签,从而报错,需要转换为feright_car或者feright-car。让成为一个完整的字符串。

转化的txt结果格式为:

使用mmrotate训练 Dronevehicle 数据集_第2张图片

训练设置

1、新增数据集类别:在路径mmrotate/datasets/下复制dota.py 更换名字为drone.py
重新定义类函数,并在__init__.py中进行注册声明
在这里插入图片描述
init.py
在这里插入图片描述
将drone.py类别内容修改为:

   """
    # CLASSES = ('plane', 'baseball-diamond', 'bridge', 'ground-track-field',
    #            'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
    #            'basketball-court', 'storage-tank', 'soccer-ball-field',
    #            'roundabout', 'harbor', 'swimming-pool', 'helicopter')

    CLASSES = ('car','truck','bus','van','feright_car')

    PALETTE = [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0), (138, 43, 226)]
    # CLASSES = ('car',)
    # PALETTE = [
    #     (0, 255, 0),
    # ]

2、在configs/base/datasets/下复制dotav1.py为drone.py,并修改其中的数据集类别以及数据集根目录
在这里插入图片描述
以及训练集,测试集,和验证集的相关路径(图片为jpg格式,标签为txt格式)

3、修改模型的配置
在configs/下有众多的模型,这里选择s2anet作为检测网络
使用mmrotate训练 Dronevehicle 数据集_第3张图片
复制s2anet_r50_fpn_1x_dota_le135.py为s2anet_r50_fpn_1x_drone_le135.py
修改其中的_base_中制定的数据集文件(drone.py)
使用mmrotate训练 Dronevehicle 数据集_第4张图片
在文件中搜索num_classes,并将设置为5类
在这里插入图片描述

开始训练

1、下载预训练权重模型
在项目官网上下载预训练权重模型
https://github.com/open-mmlab/mmrotate

2、在终端执行

python tools\train.py configs\s2anet\s2anet_r50_fpn_1x_dota_le135.py --work-dir work-dir\run\s2anet

使用mmrotate训练 Dronevehicle 数据集_第5张图片

后续

mmdetecion的检测算法是一个封装的很完整的工程项目,
优点:
封装程度高,意味着很好上手使用,
可选网络模型多,对比实验充足,
测试训练评价系统整体完善。
缺点:
高效的封装程度给改进模型带来更多的难题,并且有一些机制与传统的pytorch推理的还是有所不同,为二次开发带来了难题。

目标任务:
1、调整网络结构参数,提高检测度
2、熟悉mmdetecion结构,自定义模块改进网络
3、期待mmyolo支持旋转框检测

欢迎各位大佬一起交流,批评指正!!

你可能感兴趣的:(模式识别,pytorch,目标检测,python,开发语言)