MMrotate自定义数据集训练与验证&&格式转换脚本

  • 数据集准备

数据集格式

文件夹格式:Data/ #保存Dota数据集的目录

Train #存放images和labelTxt的文件夹

images #存放所有训练集图片的文件夹

labelTxt #存放所有训练集txt标注文件的文件夹

labelTxt中的txt文件可通过转换脚本ro_xml2txt.py生成:该脚本将roLabelImg标注的xml文件转换成DOTA格式的txt文件。

其中--xml_dir为需要转换的存放xml的路径。

--output_dir为转换后的数据集存放路径。

  • 修改
  • mmrotate/configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_ROL_le90.py
  • MMrotate自定义数据集训练与验证&&格式转换脚本_第1张图片
  • MMrotate自定义数据集训练与验证&&格式转换脚本_第2张图片
  • MMrotate自定义数据集训练与验证&&格式转换脚本_第3张图片
  • MMrotate自定义数据集训练与验证&&格式转换脚本_第4张图片
  • 修改mmrotate/mmrotate/datasets/rolabel.py

 MMrotate自定义数据集训练与验证&&格式转换脚本_第5张图片

三、训练

训练命令格式:

# 单 GPU 训练

python tools/train.py ${CONFIG_FILE} [optional arguments]

# 多 GPU 训练

bash tools/dist_train.sh ${CONFIG_FILE} ${GPU_NUM} [optional arguments]

说明:

config_file:模型配置文件的路径

gpu_num:使用 GPU 的数量

--work-dir:设置存放训练生成文件的路径

--resume-from:设置恢复训练的模型检查点文件的路径

--no-validate(不建议):设置训练时不验证模型

--seed:设置随机种子,便于复现结果

这里以oriented_rcnn为例,cd 到yuml_web目录下,运行命令:

python mmrotate/tools/train.py \

mmrotate/configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_ROL_le90.py

即可开始训练模型。其中训练产生的所有日志文件都保存在work_dir中。

  • 验证

# 单 GPU 测试

python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \

    [--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}] [--show]

# 多 GPU 测试

bash tools/dist_test.sh ${CONFIG_FILE} ${CHECKPOINT_FILE} ${GPU_NUM} \

[--out ${RESULT_FILE}] [--eval ${EVAL_METRICS}]

config_file:模型配置文件的路径

checkpoint_file:模型检查点文件的路径

gpu_num:使用的 GPU 数量

--out:设置输出 pkl 测试结果文件的路径

--work-dir:设置存放 json 日志文件的路径

--eval:设置度量指标(voc:mAP, recall | coco:bbox, segm, proposal)

--show:设置显示有预测框的测试集图像

--show-dir:设置存放有预测框的测试集图像的路径

--show-score-thr:设置显示预测框的阈值,默认值为 0.3

--fuse-conv-bn: 设置融合卷积层和批归一化层,能够稍微提升推理速度

这里以oriented_rcnn为例,建议将work_dir中需要验证的pth模型文件复制到yuml_web/checkpoints/下,然后cd 到yuml_web目录下,运行命令:

python  mmrotate/tools/test.py  mmrotate/configs/oriented_rcnn/oriented_rcnn_r50_fpn_1x_ROL_le90.py  \

checkpoints/pth模型文件名 --show

注:验证的数据集为test.txt中的图片名

转换脚本:

# -*- coding: utf-8 -*-
# @Time : 2021/5/2 15:42
# @Author : Bob.Xu
# @Site :  根据rolabelimg标注的xml文件转换为txt格式文件
# @File : xml2txt.py
# @Software: PyCharm

# Convert roLabelImg XML annotations into DOTA-format txt label files
import os
import xml.etree.ElementTree as ET
import shutil
import glob
import time
import math
import argparse



def rotatePoint(xc, yc, xp, yp, theta):
    """Rotate the point (xp, yp) about the centre (xc, yc) by ``theta``.

    Args:
        xc: x coordinate of the rotation centre.
        yc: y coordinate of the rotation centre.
        xp: x coordinate of the point to rotate.
        yp: y coordinate of the point to rotate.
        theta: rotation angle in radians (passed to ``math.cos``/``math.sin``).

    Returns:
        A ``(x, y)`` tuple holding the rotated coordinates as strings.
    """
    # Offset of the point relative to the rotation centre.
    dx = xp - xc
    dy = yp - yc

    c = math.cos(theta)
    s = math.sin(theta)

    # Apply the 2x2 rotation matrix [[c, s], [-s, c]] to the offset.
    rotated_dx = c * dx + s * dy
    rotated_dy = - s * dx + c * dy

    # Translate back and stringify so the caller can join the values directly.
    new_x = xc + rotated_dx
    new_y = yc + rotated_dy
    return str(new_x), str(new_y)

def _object_to_dota_line(obj):
    """Build the DOTA text line for one ``<object>`` element.

    The line holds the four corner points (top-left, top-right, bottom-right,
    bottom-left), followed by the class name and the ``difficult`` flag, all
    separated by single spaces.

    Args:
        obj: an ``xml.etree.ElementTree.Element`` for one ``<object>`` node.

    Returns:
        The formatted line without a trailing newline, or ``None`` when the
        object has neither a populated ``bndbox`` nor a populated ``robndbox``.
    """
    bndbox = obj.find("bndbox")
    robndbox = obj.find("robndbox")

    # Explicit None/len checks replace the deprecated Element truth test while
    # still skipping boxes that have no child elements.
    if bndbox is not None and len(bndbox):
        # Children are read by position: [0]=left x, [1]=top y,
        # [2]=right x, [3]=bottom y.
        xmin, ymin, xmax, ymax = (bndbox[i].text for i in range(4))
        coords = [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]
    elif robndbox is not None and len(robndbox):
        # Children are read by position: centre x, centre y, width, height,
        # angle.
        cx, cy, w, h, angle = (float(robndbox[i].text) for i in range(5))
        half_w = w / 2
        half_h = h / 2
        unrotated = (
            (cx - half_w, cy - half_h),  # top-left
            (cx + half_w, cy - half_h),  # top-right
            (cx + half_w, cy + half_h),  # bottom-right
            (cx - half_w, cy + half_h),  # bottom-left
        )
        coords = []
        for px, py in unrotated:
            coords.extend(rotatePoint(cx, cy, px, py, -angle))
    else:
        return None

    label = obj.find("name").text
    difficult = obj.find("difficult").text
    return " ".join(coords + [label, difficult])


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--xml_dir', type=str,default='yaodai_data/Annotations',
                        help='Directory of images and xml.')
    parser.add_argument('--output_dir', type=str,default='yaodai_data/labelTxt/',
                        help='Directory of output.')
    args = parser.parse_args()

    xml_dir = args.xml_dir
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Sorted so files are processed (and logged) in a reproducible order.
    xml_files = sorted(glob.glob(os.path.join(xml_dir, "*.xml")))

    for n, xml_path in enumerate(xml_files, start=1):
        name = os.path.basename(xml_path)
        print("正在处理第{}个xml文件,名称为{}".format(n, name))

        root = ET.parse(xml_path).getroot()

        # Write into --output_dir (previously the txt was written next to the
        # xml and --output_dir was ignored).
        stem = os.path.splitext(name)[0]
        txt_path = os.path.join(output_dir, stem + ".txt")
        with open(txt_path, 'w', encoding='utf-8') as f:
            for obj in root.iter("object"):
                line = _object_to_dota_line(obj)
                if line is not None:
                    f.write(line + "\n")

替换:

/mmrotate/mmrotate/datasets/__init__.py

# Copyright (c) OpenMMLab. All rights reserved.
# Public names of the datasets package. ROLDataset (defined in .rolabel) is
# imported and listed in __all__ next to the other dataset classes.
from .builder import build_dataset  # noqa: F401, F403
from .dota import DOTADataset  # noqa: F401, F403
from .hrsc import HRSCDataset  # noqa: F401, F403
from .pipelines import *  # noqa: F401, F403
from .sar import SARDataset  # noqa: F401, F403
from .rolabel import ROLDataset  # noqa: F401, F403

__all__ = ['SARDataset', 'DOTADataset', 'build_dataset', 'HRSCDataset', 'ROLDataset']
 

"/mmrotate/mmrotate/datasets/rolabel.py"

# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os
import os.path as osp
import re
import tempfile
import time
import zipfile
from collections import defaultdict
from functools import partial

import mmcv
import numpy as np
import torch
from mmcv.ops import nms_rotated
from mmdet.datasets.custom import CustomDataset

from mmrotate.core import obb2poly_np, poly2obb_np
from mmrotate.core.evaluation import eval_rbbox_map
from .builder import ROTATED_DATASETS


@ROTATED_DATASETS.register_module()
class ROLDataset(CustomDataset):
    """Rotated-box detection dataset read from DOTA-style txt annotations.

    The class names are read from ``class_dir`` (one name per line) while the
    class body is executed, so that file must be reachable from the current
    working directory when this module is imported.

    Args:
        ann_file (str): Folder holding the annotation txt files. If it holds
            no txt file, files ending in ``image_type`` are listed instead
            and loaded without ground truth.
        pipeline (list[dict]): Processing pipeline.
        version (str, optional): Angle representations. Defaults to 'oc'.
        difficulty (int, optional): The difficulty threshold of GT. Objects
            whose difficulty is greater than this value are dropped.
            Defaults to 100.
        image_type (str, optional): Image file extension, including the
            leading dot. Defaults to '.png'.
    """
    class_dir = 'yuml_web/data/class.txt'
    with open(class_dir, "r", encoding="utf-8") as f:
        a = [i.strip() for i in f.readlines()]
    CLASSES = a

    def __init__(self,
                 ann_file,
                 pipeline,
                 version='oc',
                 difficulty=100,
                 image_type='.png',
                 **kwargs):
        # These attributes are set before the parent constructor runs because
        # load_annotations() reads them.
        # NOTE(review): presumably CustomDataset.__init__ calls
        # load_annotations() - confirm against mmdet.
        self.version = version
        self.difficulty = difficulty
        self.image_type = image_type
        super(ROLDataset, self).__init__(ann_file, pipeline, **kwargs)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.data_infos)

    def load_annotations(self, ann_folder):
        """Load image names and ground truth from an annotation folder.

        Each non-empty line of a txt file is split on whitespace: the first
        eight fields are the polygon coordinates, field 8 is the class name
        and field 9 the integer difficulty.

        Args:
            ann_folder (str): Folder that contains DOTA-style annotation txt
                files. When no txt file is found, files matching
                ``*<image_type>`` are listed and returned with empty
                annotations (test phase).

        Returns:
            list[dict]: One dict per sample with keys ``filename`` and
            ``ann``. Annotation files of size zero are skipped.
        """
        cls_map = {c: i
                   for i, c in enumerate(self.CLASSES)
                   }  # in mmdet v2.0 label is 0-based
        ann_files = glob.glob(ann_folder + '/*.txt')
        data_infos = []
        img_ids = []
        if not ann_files:  # test phase
            ann_files = glob.glob(ann_folder + '/*' + self.image_type)
            for ann_file in ann_files:
                # splitext instead of [:-4] so extensions that are not four
                # characters long (e.g. '.jpeg') are stripped correctly.
                img_id = osp.splitext(osp.basename(ann_file))[0]
                data_info = {
                    'filename': img_id + self.image_type,
                    'ann': {
                        'bboxes': [],
                        'labels': []
                    }
                }
                data_infos.append(data_info)
                img_ids.append(img_id)
        else:
            for ann_file in ann_files:
                # Samples whose annotation file is empty are left out.
                if os.path.getsize(ann_file) == 0:
                    continue

                img_id = osp.splitext(osp.basename(ann_file))[0]
                data_info = {'filename': img_id + self.image_type, 'ann': {}}
                gt_bboxes = []
                gt_labels = []
                gt_polygons = []
                # The *_ignore lists are never filled here, so the ignore
                # entries below always end up as empty arrays.
                gt_bboxes_ignore = []
                gt_labels_ignore = []
                gt_polygons_ignore = []

                with open(ann_file) as f:
                    for line in f:
                        bbox_info = line.split()
                        poly = np.array(bbox_info[:8], dtype=np.float32)
                        try:
                            x, y, w, h, a = poly2obb_np(poly, self.version)
                        except Exception:
                            # Polygons that cannot be converted to an
                            # oriented box are skipped.
                            continue
                        cls_name = bbox_info[8]
                        difficulty = int(bbox_info[9])
                        label = cls_map[cls_name]
                        if difficulty <= self.difficulty:
                            gt_bboxes.append([x, y, w, h, a])
                            gt_labels.append(label)
                            gt_polygons.append(poly)

                ann = data_info['ann']
                if gt_bboxes:
                    ann['bboxes'] = np.array(gt_bboxes, dtype=np.float32)
                    ann['labels'] = np.array(gt_labels, dtype=np.int64)
                    ann['polygons'] = np.array(gt_polygons, dtype=np.float32)
                else:
                    ann['bboxes'] = np.zeros((0, 5), dtype=np.float32)
                    ann['labels'] = np.array([], dtype=np.int64)
                    ann['polygons'] = np.zeros((0, 8), dtype=np.float32)

                if gt_polygons_ignore:
                    ann['bboxes_ignore'] = np.array(
                        gt_bboxes_ignore, dtype=np.float32)
                    ann['labels_ignore'] = np.array(
                        gt_labels_ignore, dtype=np.int64)
                    ann['polygons_ignore'] = np.array(
                        gt_polygons_ignore, dtype=np.float32)
                else:
                    ann['bboxes_ignore'] = np.zeros((0, 5), dtype=np.float32)
                    ann['labels_ignore'] = np.array([], dtype=np.int64)
                    ann['polygons_ignore'] = np.zeros(
                        (0, 8), dtype=np.float32)

                data_infos.append(data_info)
                img_ids.append(img_id)

        # File names without extension, in the same order as data_infos.
        self.img_ids = img_ids
        return data_infos

    def _filter_imgs(self):
        """Return the indices of samples that have at least one label."""
        valid_inds = []
        for i, data_info in enumerate(self.data_infos):
            if data_info['ann']['labels'].size > 0:
                valid_inds.append(i)
        return valid_inds

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        All set to 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None,
                 nproc=4):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Only 'mAP' is
                accepted; a sequence must contain exactly one entry.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Not used by this implementation.
                Default: (100, 300, 1000).
            iou_thr (float): IoU threshold; must be a float. Default: 0.5.
            scale_ranges (list[tuple] | None): Scale ranges for evaluating mAP.
                Default: None.
            nproc (int): Processes used for computing TP and FP.
                Default: 4.

        Returns:
            dict: ``{'mAP': mean_ap}``.

        Raises:
            KeyError: If ``metric`` is not 'mAP'.
        """
        # os.cpu_count() may return None; fall back to a single process.
        nproc = min(nproc, os.cpu_count() or 1)
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = {}
        if metric == 'mAP':
            assert isinstance(iou_thr, float)
            mean_ap, _ = eval_rbbox_map(
                results,
                annotations,
                scale_ranges=scale_ranges,
                iou_thr=iou_thr,
                dataset=self.CLASSES,
                logger=logger,
                nproc=nproc)
            eval_results['mAP'] = mean_ap
        else:
            raise NotImplementedError

        return eval_results

    def merge_det(self, results, nproc=4):
        """Merging patch bboxes into full image.

        Patch image ids are expected to contain ``__<digits>___<digits>``;
        the two numbers are added to the first two columns of every box in
        that patch, and the text before the first ``__`` is used as the name
        of the full image.

        Args:
            results (list): Testing results of the dataset.
            nproc (int): number of process. Default: 4.

        Returns:
            zip: Iterator obtained by transposing the per-image
            ``(img_id, per-class detections)`` pairs from ``_merge_func``.
        """
        offset_pattern = re.compile(r'__\d+___\d+')
        collector = defaultdict(list)
        for idx in range(len(self)):
            result = results[idx]
            img_id = self.img_ids[idx]
            oriname = img_id.split('__')[0]
            x_y = offset_pattern.findall(img_id)
            x_y_2 = re.findall(r'\d+', x_y[0])
            x, y = int(x_y_2[0]), int(x_y_2[1])
            new_result = []
            for i, dets in enumerate(result):
                bboxes, scores = dets[:, :-1], dets[:, [-1]]
                ori_bboxes = bboxes.copy()
                ori_bboxes[..., :2] = ori_bboxes[..., :2] + np.array(
                    [x, y], dtype=np.float32)
                # The class index is prepended as an extra column.
                labels = np.zeros((bboxes.shape[0], 1)) + i
                new_result.append(
                    np.concatenate([labels, ori_bboxes, scores], axis=1))

            new_result = np.concatenate(new_result, axis=0)
            collector[oriname].append(new_result)

        merge_func = partial(_merge_func, CLASSES=self.CLASSES, iou_thr=0.1)
        if nproc <= 1:
            print('Single processing')
            merged_results = mmcv.track_iter_progress(
                (map(merge_func, collector.items()), len(collector)))
        else:
            print('Multiple processing')
            merged_results = mmcv.track_parallel_progress(
                merge_func, list(collector.items()), nproc)

        return zip(*merged_results)

    def _results2submission(self, id_list, dets_list, out_folder=None):
        """Generate the submission of full images.

        Writes one ``Task1_<class>.txt`` file per class into ``out_folder``
        and zips them into ``<folder name>.zip`` inside the same folder.

        Args:
            id_list (list): Id of images.
            dets_list (list): Detection results of per class.
            out_folder (str, optional): Folder of submission. It must not
                exist yet; it is created here.

        Returns:
            list[str]: Paths of the written txt files.

        Raises:
            ValueError: If ``out_folder`` already exists.
        """
        if osp.exists(out_folder):
            raise ValueError(f'The out_folder should be a non-exist path, '
                             f'but {out_folder} is existing')
        os.makedirs(out_folder)

        files = [
            osp.join(out_folder, 'Task1_' + cls + '.txt')
            for cls in self.CLASSES
        ]
        file_objs = [open(f, 'w') for f in files]
        try:
            for img_id, dets_per_cls in zip(id_list, dets_list):
                for f, dets in zip(file_objs, dets_per_cls):
                    if dets.size == 0:
                        continue
                    bboxes = obb2poly_np(dets, self.version)
                    for bbox in bboxes:
                        # Line layout: image id, last column of the row, then
                        # the remaining columns with two decimals.
                        txt_element = [img_id, str(bbox[-1])
                                       ] + [f'{p:.2f}' for p in bbox[:-1]]
                        f.write(' '.join(txt_element) + '\n')
        finally:
            # Close every handle even if writing fails part-way.
            for f in file_objs:
                f.close()

        target_name = osp.split(out_folder)[-1]
        with zipfile.ZipFile(
                osp.join(out_folder, target_name + '.zip'), 'w',
                zipfile.ZIP_DEFLATED) as t:
            for f in files:
                t.write(f, osp.split(f)[-1])

        return files

    def format_results(self, results, submission_dir=None, nproc=4, **kwargs):
        """Format the results to submission text (standard format for DOTA
        evaluation).

        Args:
            results (list): Testing results of the dataset.
            submission_dir (str, optional): The folder that contains submission
                files; it must not exist yet. If not specified, a temp folder
                will be created. Default: None.
            nproc (int, optional): number of process.

        Returns:
            tuple:

                - result_files (list[str]): paths of the per-class txt files
                - tmp_dir (tempfile.TemporaryDirectory | None): the temporary
                    directory created when submission_dir is not specified,
                    otherwise None.
        """
        nproc = min(nproc, os.cpu_count() or 1)
        assert isinstance(results, list), 'results must be a list'
        assert len(results) == len(self), (
            f'The length of results is not equal to '
            f'the dataset len: {len(results)} != {len(self)}')
        tmp_dir = None
        if submission_dir is None:
            # Keep the TemporaryDirectory object so the caller can clean it up
            # and write into a sub-folder, because _results2submission
            # requires a path that does not exist yet.
            tmp_dir = tempfile.TemporaryDirectory()
            submission_dir = osp.join(tmp_dir.name, 'results')

        print('\nMerging patch bboxes into full image!!!')
        start_time = time.time()
        id_list, dets_list = self.merge_det(results, nproc)
        stop_time = time.time()
        print(f'Used time: {(stop_time - start_time):.1f} s')

        result_files = self._results2submission(id_list, dets_list,
                                                submission_dir)

        return result_files, tmp_dir


def _merge_func(info, CLASSES, iou_thr):
    """Merging patch bboxes into full image.

    Concatenates all patch results of one image and runs rotated NMS
    separately for every class.

    Args:
        info (tuple): ``(img_id, label_dets)`` where ``label_dets`` is a list
            of 2-D arrays whose first column is the class index and whose
            remaining columns are the detection (the last one is used as the
            score, the first five as the rotated box).
        CLASSES (list): Label category.
        iou_thr (float): Threshold of IoU.

    Returns:
        tuple: ``(img_id, big_img_results)`` with one array per class.
    """
    img_id, label_dets = info
    label_dets = np.concatenate(label_dets, axis=0)

    labels, dets = label_dets[:, 0], label_dets[:, 1:]

    big_img_results = []
    for i in range(len(CLASSES)):
        # Select this class once instead of re-evaluating the mask.
        cls_dets_np = dets[labels == i]
        if len(cls_dets_np) == 0:
            big_img_results.append(cls_dets_np)
            continue

        try:
            cls_dets = torch.from_numpy(cls_dets_np).cuda()
        except Exception:
            # Moving to the GPU failed; keep the tensor on the CPU. Only
            # ordinary exceptions are caught so Ctrl-C is not swallowed.
            cls_dets = torch.from_numpy(cls_dets_np)
        nms_dets, keep_inds = nms_rotated(cls_dets[:, :5], cls_dets[:, -1],
                                          iou_thr)
        big_img_results.append(nms_dets.cpu().numpy())
    return img_id, big_img_results
 

你可能感兴趣的:(python)