mmdetection预测

mmdetection2.0版本相较于之前做了些改变,但官方文档有些小细节没来得及更新,所以在用mmdetection进行预测的时候会报错,很多博客还是原来1.x的版本,所以2.0具体可参照这篇方法:
https://blog.csdn.net/ruiying413/article/details/106477509

另外关于预测,网上都是一张图片或者几张图片的那种,没有可以读取整个文件夹的,所以自己写了一下,下面贴上预测单张和预测文件夹里面图片的代码,仅供参考:

预测单张:

from mmdet.apis import init_detector, inference_detector,show_result_pyplot
import numpy as np
import os
import cv2
import random
import mmcv


config_file = 'configs/ms_rcnn/ms_rcnn_r50_fpn_1x_coco.py'#改为自己要用的
checkpoint_file = 'work_dirs/ms_rcnn_r50_fpn_1x_coco/epoch_200.pth'#改为自己训练的模型

# build the model from a config file and a checkpoint file
model = init_detector(config_file, checkpoint_file, device='cuda:0')
test_img='demo/test.jpg'#改为自己测试的图片

# test a single image and show the results
#img = 'demo/test.jpg'  # or img = mmcv.imread(img), which will only load it once
img=test_img
result = inference_detector(model, img)
# visualize the results in a new window
model.show_result(img, result)
# or save the visualization results to image files
model.show_result(img, result, out_file='demo/result.jpg')
show_result_pyplot(model, img, result)

预测多张:

from mmdet.apis import init_detector, inference_detector,show_result_pyplot
import numpy as np
import os
import cv2
import random
import mmcv


config_file = 'configs/ms_rcnn/ms_rcnn_r50_fpn_1x_coco.py'
checkpoint_file = 'work_dirs/ms_rcnn_r50_fpn_1x_coco/epoch_200.pth'

# build the model from a config file and a checkpoint file
model = init_detector(config_file, checkpoint_file, device='cuda:0')

in_folder='data/coco/test2017/'
out_folder='data/coco/test_out2017/'

if not os.path.exists(out_folder):
    os.makedirs(out_folder)
    
for file_name in os.listdir(in_folder):
    img_path=os.path.join(in_folder,file_name)
    img=cv2.imread(img_path)
    

    # test a single image and show the results
    #img = 'demo/test.jpg'  # or img = mmcv.imread(img), which will only load it once
#    img=test_img
    result = inference_detector(model, img)
    # visualize the results in a new window
    model.show_result(img, result)
    # or save the visualization results to image files
    save_path=os.path.join(out_folder,file_name)
    model.show_result(img, result, out_file=save_path)
    show_result_pyplot(model, img, result)

你可能感兴趣的:(目标检测)