开源地址:
百度飞桨(PaddlePaddle)的 PaddleDetection 集成了多种多目标跟踪算法,地址:
PaddleDetection/configs/mot at release/2.3 · PaddlePaddle/PaddleDetection · GitHub(https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.3/configs/mot)
DeepSORT:
JDE:
FairMOT:
本人测试结果如下,后续可能会继续跟进。
本机代码:运行ok:
PaddleDetection-release-2.3
环境:Python 3.7(py37)
测试入口类:
tools/infer_mot.py
测试结果:存在漏检。
奇怪的地方:
如果读取的是视频文件,先用ffmpeg转为图片,然后排序,读取图片列表,
直接用 OpenCV 逐帧读取视频不就可以了吗?例如:
# Alternative suggested above: decode frames straight from the video with OpenCV instead of extracting them with ffmpeg first.
cap = cv2.VideoCapture(self.video_file)
电脑没有安装 ffmpeg,所以修改了程序:不再用 ffmpeg 抽帧,而是直接读取与视频同名(去掉扩展名)的文件夹中事先准备好的 jpg 图片,例如 1.mp4 对应文件夹 1/:
def _load_video_images(self):
    """Build per-frame records from JPEG frames already extracted from the video.

    Modified from the PaddleDetection version: instead of calling
    ``video2frames`` (which requires ffmpeg), the frames must already exist
    as ``*.jpg`` files in a folder next to the video, named after the video
    without its extension (e.g. ``1.mp4`` -> ``1/``).

    Side effects:
        - When ``self.frame_rate == -1`` it is replaced by the FPS that
          OpenCV reports for ``self.video_file``.
        - Sets ``self.video_frames`` (naturally sorted frame paths) and
          ``self.video_length``.
        - Fills ``self._imid2path`` with ``im_id -> frame path``.

    Returns:
        list[dict]: one record per frame with ``'im_id'`` (a 1-element numpy
        array) and ``'im_file'``; ``'keep_ori_im': 1`` is added when
        ``self.keep_ori_im`` is truthy. When ``self.sample_num > 0`` at most
        ``self.sample_num`` records are returned.

    Raises:
        AssertionError: if a listed frame file is missing or no frame is found.
    """
    if self.frame_rate == -1:
        # frame_rate was not given: ask OpenCV for the video's FPS.
        cap = cv2.VideoCapture(self.video_file)
        try:
            self.frame_rate = int(cap.get(cv2.CAP_PROP_FPS))
        finally:
            # Always release the capture handle (it was leaked before).
            cap.release()

    # Frame folder = video path minus its final extension. splitext only
    # strips the last suffix; str.replace would also remove matching text
    # elsewhere in the path (e.g. 'a.mp4/b.mp4' -> 'a/b').
    output_path = os.path.splitext(self.video_file)[0]
    # ffmpeg-based extraction (video2frames) is intentionally skipped because
    # ffmpeg is not installed; frames are read from output_path directly.
    # natsorted keeps numbered frames in numeric order ('2.jpg' < '10.jpg').
    self.video_frames = natsorted(
        glob.glob(os.path.join(output_path, '*.jpg')))
    self.video_length = len(self.video_frames)
    logger.info('Length of the video: {:d} frames.'.format(
        self.video_length))

    records = []
    for ct, image in enumerate(self.video_frames):
        assert image != '' and os.path.isfile(image), \
            "Image {} not found".format(image)
        # ct equals the number of records collected so far.
        if self.sample_num > 0 and ct >= self.sample_num:
            break
        rec = {'im_id': np.array([ct]), 'im_file': image}
        if self.keep_ori_im:
            rec.update({'keep_ori_im': 1})
        self._imid2path[ct] = image
        records.append(rec)
    assert len(records) > 0, "No image file found"
    return records
修改后的入口脚本:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
# Prepend the directory two levels above this file (the parent of the
# script's own directory, presumably the PaddleDetection repo root) to
# sys.path so the `ppdet` imports below resolve to this checkout.
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
import warnings
# Suppress every Python warning for this run (deprecation warnings included).
warnings.filterwarnings('ignore')
import paddle
from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config
from ppdet.engine import Tracker
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.utils.logger import setup_logger
# Script logger created through ppdet's setup_logger under the name 'train'.
logger = setup_logger('train')
def str2bool(value):
    """Convert a command-line string such as 'true'/'False'/'1'/'0' to bool.

    argparse's ``type=bool`` turns every non-empty string into True, so
    ``--scaled False`` used to yield True; this converter fixes that.
    Real bools (e.g. argparse defaults) are returned unchanged.

    Raises:
        ValueError: for unrecognized strings (argparse reports it as an
        invalid argument value).
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise ValueError("Boolean value expected, got {!r}.".format(value))


def parse_args():
    """Parse the command-line flags for MOT inference.

    Boolean flags accept an explicit value (``--save_images False``) or can
    be given bare (``--save_images``), which means True.

    Returns:
        argparse.Namespace: the parsed flags (as produced by ``ArgsParser``).
    """
    parser = ArgsParser()
    # NOTE(review): if ArgsParser already registers -c/--config, this
    # re-declaration would conflict in argparse — confirm against ppdet.utils.cli.
    parser.add_argument('--config', type=str, default="../configs/mot/fairmot/fairmot_dla34_30e_576x320.yml",
                        help='Path of the tracking config (.yml) file.')
    parser.add_argument('--video_file', type=str, default="1.mp4", help='Video name for tracking.')
    parser.add_argument('--frame_rate', type=int, default=-1, help='Video frame rate for tracking.')
    parser.add_argument("--image_dir", type=str, default=None, help="Directory for images to perform inference on.")
    parser.add_argument("--det_results_dir", type=str, default='', help="Directory name for detection results.")
    parser.add_argument('--output_dir', type=str, default='output', help='Directory name for output tracking results.')
    # Boolean flags: parsed with str2bool so "False" really means False;
    # nargs='?' + const=True lets the bare flag act as True.
    parser.add_argument('--save_images', type=str2bool, nargs='?', const=True, default=False,
                        help='Save tracking results (image).')
    parser.add_argument('--save_videos', type=str2bool, nargs='?', const=True, default=False,
                        help='Save tracking results (video).')
    parser.add_argument('--show_image', type=str2bool, nargs='?', const=True, default=True,
                        help='Show tracking results (image).')
    parser.add_argument('--scaled', type=str2bool, nargs='?', const=True, default=False,
                        help="Whether coords after detector outputs are scaled, False in JDE YOLOv3 "
                             "True in general detector.")
    parser.add_argument("--draw_threshold", type=float, default=0.5, help="Threshold to reserve the result for visualization.")
    args = parser.parse_args()
    return args
def run(FLAGS, cfg):
    """Create a Tracker from ``cfg``, load its weights, then run MOT inference.

    For the 'DeepSORT' architecture, detector and ReID weights are loaded
    through ``load_weights_sde``; the detector weights are passed as None
    when ``cfg.det_weights`` is the string 'None'. Any other architecture
    loads ``cfg.weights`` through ``load_weights_jde``. The remaining
    inference options are taken from the parsed command-line ``FLAGS``.
    """
    tracker = Tracker(cfg, mode='test')

    uses_sde_weights = cfg.architecture in ['DeepSORT']
    if not uses_sde_weights:
        tracker.load_weights_jde(cfg.weights)
    else:
        # The config spells "no detector weights" as the literal string 'None'.
        det_weights = None if cfg.det_weights == 'None' else cfg.det_weights
        tracker.load_weights_sde(det_weights, cfg.reid_weights)

    tracker.mot_predict(
        video_file=FLAGS.video_file,
        frame_rate=FLAGS.frame_rate,
        image_dir=FLAGS.image_dir,
        data_type=cfg.metric.lower(),
        model_type=cfg.architecture,
        output_dir=FLAGS.output_dir,
        save_images=FLAGS.save_images,
        save_videos=FLAGS.save_videos,
        show_image=FLAGS.show_image,
        scaled=FLAGS.scaled,
        det_results_dir=FLAGS.det_results_dir,
        draw_threshold=FLAGS.draw_threshold,
    )
if __name__ == '__main__':
    # Parse CLI flags, load the YAML config and merge the overrides in FLAGS.opt.
    FLAGS = parse_args()
    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)

    # Validation helpers from ppdet.utils.check.
    check_config(cfg)
    check_gpu(cfg.use_gpu)
    check_version()

    # Pick the device: this process's GPU when use_gpu is set, otherwise CPU.
    if cfg.use_gpu:
        place = 'gpu:{}'.format(ParallelEnv().dev_id)
    else:
        place = 'cpu'
    place = paddle.set_device(place)

    run(FLAGS, cfg)