基于百度的 PaddleDetection 项目,实现目标检测 labelme 标注的自动生成:需要先训练一个模型,然后用该模型对图片进行自动标注,最后用 labelme 对标注结果进行微调。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import json
import io
import base64
# Add the parent of this file's directory (expected to be the PaddleDetection
# root) to the front of sys.path. This must run BEFORE the first `ppdet`
# import, otherwise the in-tree package is not on the import path yet.
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

# Silence warnings for the rest of the run (including those raised on import).
import warnings
warnings.filterwarnings('ignore')

import glob

import paddle
from PIL import Image, ImageOps
import numpy as np

from ppdet.data.source.category import get_categories
from ppdet.optimizer import ModelEMA
from ppdet.utils.checkpoint import load_pretrain_weight
from ppdet.utils.visualizer import visualize_results
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.check import check_gpu, check_npu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.slim import build_slim_model
from ppdet.metrics import get_infer_results
from ppdet.utils.logger import setup_logger

# Module-level logger shared by the helpers below.
logger = setup_logger('train')
def parse_args():
parser = ArgsParser()
parser.add_argument(
"--infer_dir",
type=str,
default=r"C:\Users\86187\Desktop\classifyUnrecognizedESAll",
help="Directory for images to perform inference on.")
parser.add_argument(
"--infer_img",
type=str,
default=None,
help="Image path, has higher priority over --infer_dir")
parser.add_argument(
"--output_dir",
type=str,
default="python_infer_output",
help="Directory for storing the output visualization files.")
parser.add_argument(
"--draw_threshold",
type=float,
default=0.5,
help="Threshold to reserve the result for visualization.")
parser.add_argument(
"--slim_config",
default=None,
type=str,
help="Configuration file of slim method.")
parser.add_argument(
"--use_vdl",
type=bool,
default=False,
help="Whether to record the data to VisualDL.")
parser.add_argument(
'--vdl_log_dir',
type=str,
default="vdl_log_dir/image",
help='VisualDL logging directory for image.')
parser.add_argument(
"--save_txt",
type=bool,
default=False,
help="Whether to save inference result in txt.")
args = parser.parse_args()
return args
def img_arr_to_b64(img_pil):
    """
    Encode an image as PNG and return the base64 form of those bytes.

    Args:
        img_pil: object exposing `save(fp, format=...)`; callers in this
            file pass a PIL image.

    Returns:
        bytes: base64 text of the PNG data, as produced by
        `base64.encodebytes` (or `base64.encodestring` where the former
        does not exist).
    """
    # Pick whichever encoder this interpreter provides.
    encode = getattr(base64, "encodebytes", None)
    if encode is None:
        encode = base64.encodestring

    with io.BytesIO() as buffer:
        img_pil.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()
    return encode(png_bytes)
def _get_save_image_name(output_dir, image_path):
"""
Get save image name from source image path.
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir)
image_name = os.path.split(image_path)[-1]
name, ext = os.path.splitext(image_name)
return os.path.join(output_dir, "{}".format(name)) + ext
def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode

    Args:
        infer_img: path of a single image, or None. When given it wins over
            `infer_dir`, which is then not inspected at all.
        infer_dir: directory searched (non-recursively) for files ending in
            jpg/jpeg/png/bmp, in lower or upper case; or None.

    Returns:
        list: `[infer_img]`, or the sorted list of images found in
        `infer_dir`.

    Raises:
        AssertionError: if neither argument is set, `infer_img` is not a
            file, `infer_dir` is not a directory, or no image is found.
    """
    assert infer_img is not None or infer_dir is not None, \
        "--infer_img or --infer_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
        "{} is not a file".format(infer_img)

    # infer_img has a higher priority: return it before validating
    # infer_dir, so an unusable infer_dir cannot block a valid single image.
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    assert infer_dir is not None and os.path.isdir(infer_dir), \
        "{} is not a directory".format(infer_dir)
    infer_dir = os.path.abspath(infer_dir)

    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]
    # A set removes duplicates on case-insensitive file systems, where the
    # lower- and upper-case patterns match the same files.
    images = set()
    for ext in exts:
        images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))

    # Sort so that the processing order is reproducible between runs.
    images = sorted(images)
    assert len(images) > 0, "no image found in {}".format(infer_dir)
    logger.info("Found {} inference images in total.".format(len(images)))

    return images
def _coco17_category():
"""
Get class id to category id map and category id
to category name map of COCO2017 dataset
"""
clsid2catid = {
1: 1,
2: 2,
3: 3,
4: 4,
5: 5,
6: 6,
7: 7,
8: 8,
9: 9,
10: 10,
11: 11,
12: 13,
13: 14,
14: 15,
15: 16,
16: 17,
17: 18,
18: 19,
19: 20,
20: 21,
21: 22,
22: 23,
23: 24,
24: 25,
25: 27,
26: 28,
27: 31,
28: 32,
29: 33,
30: 34,
31: 35,
32: 36,
33: 37,
34: 38,
35: 39,
36: 40,
37: 41,
38: 42,
39: 43,
40: 44,
41: 46,
42: 47,
43: 48,
44: 49,
45: 50,
46: 51,
47: 52,
48: 53,
49: 54,
50: 55,
51: 56,
52: 57,
53: 58,
54: 59,
55: 60,
56: 61,
57: 62,
58: 63,
59: 64,
60: 65,
61: 67,
62: 70,
63: 72,
64: 73,
65: 74,
66: 75,
67: 76,
68: 77,
69: 78,
70: 79,
71: 80,
72: 81,
73: 82,
74: 84,
75: 85,
76: 86,
77: 87,
78: 88,
79: 89,
80: 90
}
catid2name = {
0: 'background',
1: 'barCode',
2: '5',
3: '4',
4: '8',
5: '6',
6: '9',
7: '3',
8: '0',
9: '2',
10: '1',
11: 'fire hydrant',
13: 'stop sign',
14: 'parking meter',
15: 'bench',
16: 'bird',
17: 'cat',
18: 'dog',
19: 'horse',
20: 'sheep',
21: 'cow',
22: 'elephant',
23: 'bear',
24: 'zebra',
25: 'giraffe',
27: 'backpack',
28: 'umbrella',
31: 'handbag',
32: 'tie',
33: 'suitcase',
34: 'frisbee',
35: 'skis',
36: 'snowboard',
37: 'sports ball',
38: 'kite',
39: 'baseball bat',
40: 'baseball glove',
41: 'skateboard',
42: 'surfboard',
43: 'tennis racket',
44: 'bottle',
46: 'wine glass',
47: 'cup',
48: 'fork',
49: 'knife',
50: 'spoon',
51: 'bowl',
52: 'banana',
53: 'apple',
54: 'sandwich',
55: 'orange',
56: 'broccoli',
57: 'carrot',
58: 'hot dog',
59: 'pizza',
60: 'donut',
61: 'cake',
62: 'chair',
63: 'couch',
64: 'potted plant',
65: 'bed',
67: 'dining table',
70: 'toilet',
72: 'tv',
73: 'laptop',
74: 'mouse',
75: 'remote',
76: 'keyboard',
77: 'cell phone',
78: 'microwave',
79: 'oven',
80: 'toaster',
81: 'sink',
82: 'refrigerator',
84: 'book',
85: 'clock',
86: 'vase',
87: 'scissors',
88: 'teddy bear',
89: 'hair drier',
90: 'toothbrush'
}
clsid2catid = {k - 1: v for k, v in clsid2catid.items()}
catid2name.pop(0)
return clsid2catid, catid2name
def getImagesLabels(image, bboxes, im_id, save_image_name, catid2name, draw_threshold):
    """
    Write the detections of one image as a labelme-style JSON file.

    The JSON file is written next to `save_image_name`, with the final
    extension replaced by '.json'. The image itself is embedded in the
    'imageData' field as base64-encoded PNG.

    Args:
        image: image object providing `.size` as (width, height) and
            accepted by `img_arr_to_b64`; callers in this file pass a PIL
            image.
        bboxes: iterable of detection dicts with the keys 'image_id',
            'category_id', 'bbox' (xmin, ymin, width, height) and 'score';
            None is treated as "no detections".
        im_id: id of the image; detections with another 'image_id' are
            skipped.
        save_image_name: path whose directory and stem determine where the
            JSON file goes and whose file name is stored as 'imagePath'.
        catid2name: mapping from category id to label text.
        draw_threshold: detections scoring below this value are dropped.

    Raises:
        KeyError: if a kept detection has a category id missing from
            `catid2name`.
    """
    # Replace only the final extension; splitting on the first '.' would
    # truncate paths such as './labelme/a.jpg' or 'img.v2.jpg'.
    save_json_name = os.path.splitext(save_image_name)[0] + '.json'
    img_w, img_h = image.size

    # Build the label document
    dst_data = dict()
    dst_data['version'] = "4.5.9"
    dst_data['flags'] = {}
    dst_data['shapes'] = []
    # Keep only the file name, whichever path separator was used.
    dst_data['imagePath'] = os.path.basename(save_image_name.replace("\\", "/"))
    dst_data['imageData'] = img_arr_to_b64(image).decode('utf-8')
    dst_data['imageHeight'] = img_h
    dst_data['imageWidth'] = img_w

    for dt in (bboxes if bboxes is not None else []):
        if im_id != dt['image_id']:
            continue
        catid, bbox, score = dt['category_id'], dt['bbox'], dt['score']
        if score < draw_threshold:
            continue

        # Plain floats keep the points JSON-serializable for any numeric
        # type found in the bbox.
        xmin, ymin, box_w, box_h = (float(v) for v in bbox)
        shape = dict()
        shape['label'] = catid2name[catid]
        # Rectangle given by its top-left and bottom-right corners.
        shape['points'] = [[xmin, ymin], [xmin + box_w, ymin + box_h]]
        shape['group_id'] = None
        shape['shape_type'] = "rectangle"
        shape['flags'] = {}
        dst_data['shapes'].append(shape)

    print(save_json_name)
    # The context manager guarantees the file is flushed and closed.
    with open(save_json_name, 'w') as json_file:
        json.dump(dst_data, json_file, indent=4)
def run(FLAGS, cfg):
    """
    Run detection on the requested images. For every image, write a labelme
    JSON annotation into 'labelme/' and save a visualization of the
    detections into the output directory.

    Args:
        FLAGS: parsed command-line options. `infer_dir` and `infer_img` are
            required; `draw_threshold` and `output_dir` are honoured when
            present (fallbacks: 0.5 and 'output').
        cfg: loaded configuration. This function reads 'TestDataset',
            `architecture`, 'TestReader', `weights` and the optional
            'use_ema', 'ema_decay' and 'cycle_epoch' entries.
    """
    # Honour the command-line options instead of fixed values.
    draw_threshold = getattr(FLAGS, 'draw_threshold', 0.5)
    output_dir = getattr(FLAGS, 'output_dir', 'output')
    labels_dir = 'labelme'

    # build data loader
    mode = 'test'
    dataset = cfg['{}Dataset'.format(mode.capitalize())]

    # get inference images
    images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)

    # build model
    model = create(cfg.architecture)
    # normalize params for deploy
    model.load_meanstd(cfg['TestReader']['sample_transforms'])

    use_ema = ('use_ema' in cfg and cfg['use_ema'])
    if use_ema:
        ema_decay = cfg.get('ema_decay', 0.9998)
        cycle_epoch = cfg.get('cycle_epoch', -1)
        # NOTE(review): the EMA wrapper is constructed but never referenced
        # again in this function -- confirm whether it is still needed.
        ema = ModelEMA(
            model,
            decay=ema_decay,
            use_thres_step=True,
            cycle_epoch=cycle_epoch)

    # load weights
    load_pretrain_weight(model, cfg.weights)

    # predict
    dataset.set_images(images)
    loader = create('TestReader')(dataset, 0)

    imid2path = dataset.get_imid2path()
    # NOTE(review): the returned annotation file is not used below.
    anno_file = dataset.get_anno()
    # Labels come from the table in this file, not from the dataset.
    clsid2catid, catid2name = _coco17_category()

    # Run Infer
    model.eval()
    results = []
    for step_id, data in enumerate(loader):
        # forward
        outs = model(data)

        # Carry the per-image metadata along with the raw outputs.
        for key in ['im_shape', 'scale_factor', 'im_id']:
            outs[key] = data[key]
        # Convert everything that offers .numpy() into numpy values.
        for key, value in outs.items():
            if hasattr(value, 'numpy'):
                outs[key] = value.numpy()
        results.append(outs)

    for outs in results:
        batch_res = get_infer_results(outs, clsid2catid)
        bbox_num = outs['bbox_num']

        # `start:end` is the slice of this batch's boxes that belongs to
        # the current image.
        start = 0
        for i, im_id in enumerate(outs['im_id']):
            image_path = imid2path[int(im_id)]
            image = Image.open(image_path).convert('RGB')
            image = ImageOps.exif_transpose(image)

            end = start + bbox_num[i]
            bbox_res = batch_res['bbox'][start:end] \
                if 'bbox' in batch_res else None

            # Generate the labelme annotation from the detections; an
            # absent bbox result becomes an annotation without shapes.
            save_label_name = _get_save_image_name(labels_dir, image_path)
            getImagesLabels(image, bbox_res if bbox_res is not None else [],
                            int(im_id), save_label_name, catid2name,
                            draw_threshold)

            # Visualize the detections
            image = visualize_results(
                image, bbox_res, None, None, None,
                int(im_id), catid2name, draw_threshold)

            # save image with detection
            save_name = _get_save_image_name(output_dir, image_path)
            logger.info("Detection bbox results save in {}".format(
                save_name))
            image.save(save_name, quality=95)

            start = end
def main():
    """
    Entry point: parse the command line, load and adjust the configuration,
    select the device, run the environment checks and start inference.
    """
    FLAGS = parse_args()

    cfg = load_config(FLAGS.config)
    cfg['use_vdl'] = FLAGS.use_vdl
    cfg['vdl_log_dir'] = FLAGS.vdl_log_dir
    merge_config(FLAGS.opt)

    # disable npu in config by default
    if 'use_npu' not in cfg:
        cfg.use_npu = False

    # Device preference: GPU first, then NPU, otherwise CPU.
    if cfg.use_gpu:
        device = 'gpu'
    elif cfg.use_npu:
        device = 'npu'
    else:
        device = 'cpu'
    paddle.set_device(device)

    # Without a GPU, 'sync_bn' is replaced by plain 'bn'.
    wants_sync_bn = 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn'
    if wants_sync_bn and not cfg.use_gpu:
        cfg['norm_type'] = 'bn'

    if FLAGS.slim_config:
        cfg = build_slim_model(cfg, FLAGS.slim_config, mode='test')

    check_config(cfg)
    check_gpu(cfg.use_gpu)
    check_npu(cfg.use_npu)
    check_version()

    run(FLAGS, cfg)


if __name__ == '__main__':
    main()