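When studying object detection algorithms, you need to run comparison experiments. MMDetection is a convenient way to do this, because the framework bundles many existing object detection algorithms in one place.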
First, the environment setup. I have configured this before, so I will only record it briefly here.
Create a conda environment:
conda create --name openmmlab python=3.7 -y
conda activate openmmlab
Install PyTorch:
pip install torch==1.7.0+cu110 torchvision==0.8.1+cu110 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
Installing PyTorch with pip is recommended; otherwise you may get the following error:
OSError: /home/ubuntu/.conda/envs/mmdet/lib/python3.7/site-packages/torch/lib/../../../../libcublas.so.11: symbol free_gemm_select version libcublasLt.so.11 not defined in file libcublasLt.so.11 with link time reference
Install mmcv-full. The MMCV build must match your CUDA and PyTorch versions, so the install command follows this format:
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html
For example, with CUDA 11.0 and PyTorch 1.7.0:
pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7.0/index.h
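Filling in the `{cu_version}` and `{torch_version}` placeholders by hand is easy to get wrong, so here is a small sketch that builds the find-links URL from the template above. The function name `mmcv_full_find_links` is hypothetical; it only does string formatting, assuming the CUDA tag is `cu` plus the version digits (11.0 becomes `cu110`) and that a local build tag such as `+cu110` in `torch.__version__` should be dropped.

```python
def mmcv_full_find_links(cuda_version: str, torch_version: str) -> str:
    """Build the -f URL for `pip install mmcv-full` (hypothetical helper).

    Template: https://download.openmmlab.com/mmcv/dist/{cu_version}/{torch_version}/index.html
    """
    # "11.0" -> "cu110" (assumption: the tag is "cu" + version digits)
    cu = "cu" + cuda_version.replace(".", "")
    # Drop a local build tag like "+cu110" if the full torch.__version__ is passed
    torch_base = torch_version.split("+", 1)[0]
    return f"https://download.openmmlab.com/mmcv/dist/{cu}/torch{torch_base}/index.html"


print(mmcv_full_find_links("11.0", "1.7.0+cu110"))
```

The torch part is passed through as given, so for the mmcv 2.x command used later in this post (which points at `torch1.7`) you would pass `"1.7"` instead of `"1.7.0"`.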
Install mmdet:
pip install mmdet -i https://pypi.tuna.tsinghua.edu.cn/simple
Running it raises an error:
File "/home/ubuntu/.conda/envs/mmdet/lib/python3.7/site-packages/mmdet/__init__.py", line 18, in <module>
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
AssertionError: MMCV==1.7.1 is used but incompatible. Please install mmcv>=2.0.0rc4, <2.1.0.
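The installed mmcv version is incompatible: this mmdet requires mmcv>=2.0.0rc4, <2.1.0. Go to the following page:
https://mmcv.readthedocs.io/en/latest/get_started/installation.html
Then choose the mmcv build that matches your CUDA and PyTorch versions: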
pip install mmcv==2.0.0rc4 -f https://download.openmmlab.com/mmcv/dist/cu110/torch1.7/index.html
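To see why 1.7.1 was rejected and 2.0.0rc4 is accepted, the range in the error message can be checked with a few lines of stdlib Python. This is a sketch, not mmdet's own version check: the parser is a hypothetical helper that only understands `X.Y.Z` with an optional `rcN` suffix, which covers the versions in this post but not full PEP 440.

```python
import re


def parse_version(v: str) -> tuple:
    """Turn '2.0.0rc4' or '1.7.1' into a comparable tuple (hypothetical helper).

    Only handles 'X.Y.Z' with an optional 'rcN' suffix.
    """
    m = re.fullmatch(r"(\d+)\.(\d+)\.(\d+)(?:rc(\d+))?", v)
    if m is None:
        raise ValueError(f"unsupported version string: {v!r}")
    major, minor, patch, rc = m.groups()
    # A release candidate sorts before the final release: 2.0.0rc4 < 2.0.0
    pre = (0, int(rc)) if rc is not None else (1, 0)
    return (int(major), int(minor), int(patch)) + pre


def in_range(version: str, lower: str, upper: str) -> bool:
    """True if lower <= version < upper, mirroring 'mmcv>=2.0.0rc4, <2.1.0'."""
    return parse_version(lower) <= parse_version(version) < parse_version(upper)


print(in_range("1.7.1", "2.0.0rc4", "2.1.0"))     # the version that triggered the error
print(in_range("2.0.0rc4", "2.0.0rc4", "2.1.0"))  # the version installed above
```

Inside the conda environment you would pass `mmcv.__version__` as the first argument; for anything beyond this simple format, the `packaging.version.Version` class is the more robust choice.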
Next, run demo/image_demo.py. You only need to change three arguments: inputs (the input image), model (the config file), and weights (the checkpoint file).
from argparse import ArgumentParser

from mmengine.logging import print_log

from mmdet.apis import DetInferencer


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        '--inputs', type=str, default="/home/ubuntu/programs/mmdetection/images/000000263594.jpg", help='Input image file or folder path.')
    parser.add_argument(
        '--model', type=str, default="/home/ubuntu/programs/mmdetection/configs/faster_rcnn/faster-rcnn_r50_fpn_2x_coco.py",
        help='Config or checkpoint .pth file or the model name '
        'and alias defined in metafile. The model configuration '
        'file will try to read from .pth if the parameter is '
        'a .pth weights file.')
    parser.add_argument('--weights', default="/home/ubuntu/programs/mmdetection/weights/faster_rcnn_r50_fpn_2x_coco.pth", help='Checkpoint file')
    parser.add_argument(
        '--out-dir',
        type=str,
        default='/home/ubuntu/programs/mmdetection/outputs/',
        help='Output directory of images or prediction results.')
    parser.add_argument('--texts', help='text prompt')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--pred-score-thr',
        type=float,
        default=0.8,
        help='bbox score threshold')
    parser.add_argument(
        '--batch-size', type=int, default=1, help='Inference batch size.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='Display the image in a popup window.')
    parser.add_argument(
        '--no-save-vis',
        action='store_true',
        help='Do not save detection vis results')
    parser.add_argument(
        '--no-save-pred',
        action='store_true',
        help='Do not save detection json results')
    parser.add_argument(
        '--print-result',
        action='store_true',
        help='Whether to print the results.')
    parser.add_argument(
        '--palette',
        default='none',
        choices=['coco', 'voc', 'citys', 'random', 'none'],
        help='Color palette used for visualization')
    # only for GLIP
    parser.add_argument(
        '--custom-entities',
        '-c',
        action='store_true',
        help='Whether to customize entity names? '
        'If so, the input text should be '
        '"cls_name1 . cls_name2 . cls_name3 ." format')

    call_args = vars(parser.parse_args())

    if call_args['no_save_vis'] and call_args['no_save_pred']:
        call_args['out_dir'] = ''

    if call_args['model'].endswith('.pth'):
        print_log('The model is a weight file, automatically '
                  'assign the model to --weights')
        call_args['weights'] = call_args['model']
        call_args['model'] = None

    init_kws = ['model', 'weights', 'device', 'palette']
    init_args = {}
    for init_kw in init_kws:
        init_args[init_kw] = call_args.pop(init_kw)

    return init_args, call_args


def main():
    init_args, call_args = parse_args()
    # TODO: Video and Webcam are currently not supported and
    # may consume too much memory if your input folder has a lot of images.
    # We will be optimized later.
    inferencer = DetInferencer(**init_args)
    inferencer(**call_args)

    if call_args['out_dir'] != '' and not (call_args['no_save_vis']
                                           and call_args['no_save_pred']):
        print_log(f'results have been saved at {call_args["out_dir"]}')


if __name__ == '__main__':
    main()
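The only non-obvious part of `parse_args()` is how the parsed arguments are split: `model`, `weights`, `device` and `palette` go to the `DetInferencer` constructor, everything else goes to the inferencer call, and a `.pth` passed as `--model` is moved to `weights`. Below is a stdlib-only sketch of that splitting logic so it can be tried without MMDetection installed; the name `split_args` is hypothetical and the input dict is a trimmed-down stand-in for the full argparse result.

```python
INIT_KWS = ("model", "weights", "device", "palette")


def split_args(call_args: dict) -> tuple:
    """Sketch of parse_args()'s post-processing (hypothetical helper)."""
    call_args = dict(call_args)  # work on a copy, leave the caller's dict alone
    # Saving neither visualizations nor predictions: no output directory needed
    if call_args["no_save_vis"] and call_args["no_save_pred"]:
        call_args["out_dir"] = ""
    # A checkpoint passed as --model is treated as --weights
    if call_args["model"].endswith(".pth"):
        call_args["weights"] = call_args["model"]
        call_args["model"] = None
    # Constructor kwargs are popped out; what remains is passed to the call
    init_args = {kw: call_args.pop(kw) for kw in INIT_KWS}
    return init_args, call_args


args = {"model": "faster_rcnn_r50_fpn_2x_coco.pth", "weights": None,
        "device": "cuda:0", "palette": "none", "inputs": "demo.jpg",
        "out_dir": "outputs/", "no_save_vis": False, "no_save_pred": False}
init_args, rest = split_args(args)
print(init_args)
print(sorted(rest))
```

This is why passing only a `.pth` file as `--model` also works: the config is then read from the checkpoint, as the `--model` help text says.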
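After downloading the trained weights, pair them with the matching config file. The weights are listed on GitHub: for Faster R-CNN, for example, go to configs/faster_rcnn, which contains many Faster R-CNN variants:
https://github.com/open-mmlab/mmdetection/tree/main/configs/faster_rcnn
The corresponding weight files are linked further down that page; download one and you can run inference.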
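Note that in this post the config is named `faster-rcnn_r50_fpn_2x_coco.py` (hyphen) while the checkpoint is `faster_rcnn_r50_fpn_2x_coco.pth` (underscore). A quick sanity check before running the demo is to compare the two names. The sketch below is a heuristic of my own, not an MMDetection utility: it assumes hyphens and underscores are interchangeable and that a downloaded checkpoint may carry extra suffixes (such as metrics or a hash) after the config name. `weights_match_config` is a hypothetical name.

```python
from pathlib import PurePath


def weights_match_config(config_path: str, weights_path: str) -> bool:
    """Heuristic name check between a config and a checkpoint (hypothetical helper)."""
    # Compare file stems with '-' normalized to '_'
    cfg = PurePath(config_path).stem.replace("-", "_")
    ckpt = PurePath(weights_path).stem.replace("-", "_")
    # Exact match, or the checkpoint name extends the config name with '_...'
    return ckpt == cfg or ckpt.startswith(cfg + "_")


print(weights_match_config(
    "configs/faster_rcnn/faster-rcnn_r50_fpn_2x_coco.py",
    "weights/faster_rcnn_r50_fpn_2x_coco.pth"))
```

A mismatch (for example a 1x checkpoint with a 2x config) does not always fail loudly, so a check like this can save a confusing debugging session; it is only a name comparison and says nothing about whether the weights actually load.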
The inference result is shown below: