在上一篇文章中介绍了MMRotate的概述、安装和训练Dota数据集全流程,由于文章篇幅限制还剩下一部分模型的推理和部署环节没有写,为避免后续对这部分工作的遗忘,决定还是补充上这部分的笔记,仅作记录,如有不足之处还请指出!
可以首先使用官网上的推理方法,从源代码安装mmrotate, 只需运行以下命令。
# 首先把所需要的配置文件下载下来,或直接用你训练好的模型和配置。
mim download mmrotate --config oriented-rcnn-le90_r50_fpn_1x_dota --dest .
# 推理单张图片,也可验证mmrotate是否正确推理
python demo/image_demo.py demo/demo.jpg oriented_rcnn_r50_fpn_1x_dota_le90.py
oriented_rcnn_r50_fpn_1x_dota_le90-6d2b2ce0.pth --out-file result.jpg
执行成功之后会生成result.jpg,就是模型的推理结果。
参考地址:https://mmrotate.readthedocs.io/zh-cn/1.x/get_started.html
您可以使用以下命令来推理数据集进行模型测试。
# 单个 GPU
python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [optional arguments]
例如:
python ./tools/test.py oriented-rcnn-le90_r50_fpn_1x_dota.py oriented_rcnn_r50_fpn_1x_dota_le90-6d2b2ce0.pth
参考地址:https://mmrotate.readthedocs.io/zh-cn/stable/get_started.html#id2
通过上述方法,我们可以简单测试模型的性能,如果模型性能满足要求之后,我们想生产部署或者批量推理的时候,我们需要将模型的推理结果导出为xml文件,可以借助代码中的demo/huge_image_demo.py
获取模型的推理结果,提取结果中的标签信息,并将结果保存为xml文件。
修改后的完整推理代码如下:
run,py
# Copyright (c) OpenMMLab. All rights reserved.
import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import numpy as np
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
import xml.etree.ElementTree as ET
from xml.dom.minidom import parseString
from glob import glob
from tqdm import tqdm
from argparse import ArgumentParser
import cv2
from datetime import datetime
import mmcv
from mmdet.apis import init_detector
from mmrotate.apis import inference_detector_by_patches
from mmrotate.registry import VISUALIZERS
from mmrotate.utils import register_all_modules
def parse_args():
parser = ArgumentParser()
parser.add_argument('input_path', help='Image file')
parser.add_argument('--config', default='configs/rotated_rtmdet/rotated_rtmdet_l-100e-aug-dota.py', help='Config file')
parser.add_argument('--checkpoint', default='rtmdet_l_epoch_200.pth', help='Checkpoint file')
parser.add_argument('output_path', default='/output_path', help='Path to output file')
parser.add_argument(
'--patch_sizes',
type=int,
nargs='+',
default=[1024],
help='The sizes of patches')
parser.add_argument(
'--patch_steps',
type=int,
nargs='+',
default=[768], # 824
help='The steps between two patches')
parser.add_argument(
'--img_ratios',
type=float,
nargs='+',
default=[1.0],
help='Image resizing ratios for multi-scale detecting')
parser.add_argument(
'--merge_iou_thr',
type=float,
default=0.3,
help='IoU threshold for merging results')
parser.add_argument(
'--merge_nms_type',
default='nms_rotated',
choices=['nms', 'nms_rotated', 'nms_quadri'],
help='NMS type for merging results')
parser.add_argument(
'--device', default='cuda:0', help='Device used for inference')
parser.add_argument(
'--palette',
default='dota',
choices=['dota', 'sar', 'hrsc', 'random'],
help='Color palette used for visualization')
parser.add_argument(
'--score-thr', type=float, default=0.3, help='bbox score threshold')
args = parser.parse_args()
return args
def remove_whitespace_nodes(node):
for child in list(node.childNodes):
if child.nodeType == child.TEXT_NODE and not child.data.strip():
node.removeChild(child)
elif child.hasChildNodes():
remove_whitespace_nodes(child)
def create_empty_xml(image_path, output_path):
root = ET.Element("annotation")
# Adding source details
source_x = ET.SubElement(root, "source")
file_n = ET.SubElement(source_x, "filename")
orig = ET.SubElement(source_x, "origin")
file_n.text = os.path.basename(image_path)
orig.text = "Optical"
# Adding research details
research = ET.SubElement(root, "re