百度飞桨（PaddlePaddle）行人多目标跟踪笔记

开源地址:

PaddleDetection/configs/mot at release/2.3 · PaddlePaddle/PaddleDetection · GitHub（https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.3/configs/mot）

百度飞桨的 PaddleDetection 集成了多种多目标跟踪算法，地址：

PaddleDetection/configs/mot at release/2.3 · PaddlePaddle/PaddleDetection · GitHub（https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.3/configs/mot）

DeepSORT：

JDE：

FairMOT：

本人测试结果如下，后续可能会继续跟进。

本机代码（运行正常）：

PaddleDetection-release-2.3

环境：Python 3.7

测试入口脚本：

tools/infer_mot.py

测试结果：存在漏检。

奇怪的地方:

如果输入的是视频文件，程序会先用 ffmpeg 将视频转为图片，再排序并读取图片列表。

直接读取图片（视频帧）就可以吧？例如用 OpenCV 直接打开视频：

 cap = cv2.VideoCapture(self.video_file)

本机没有安装 ffmpeg，所以修改了程序：注释掉 video2frames 抽帧调用，改为直接读取与视频同名（去掉扩展名）的文件夹中事先抽好的 jpg 图片：

    def _load_video_images(self):
        """Build inference records from frames pre-extracted from ``self.video_file``.

        Locally patched version: the ffmpeg-based ``video2frames`` extraction
        step is skipped. The frames must already exist as ``*.jpg`` files in
        a directory named after the video with its extension removed
        (e.g. ``videos/1.mp4`` -> ``videos/1/``). Frames are ordered with
        ``natsorted`` (natural order, so ``10.jpg`` follows ``9.jpg``).

        Side effects:
            Sets ``self.frame_rate`` from the video's FPS metadata when it is
            -1. Also sets ``self.video_frames`` and ``self.video_length``, and
            fills ``self._imid2path`` (frame index -> image path).

        Returns:
            list of dict: one record per frame with ``'im_id'`` (a
            one-element ``np.ndarray`` holding the frame index) and
            ``'im_file'``. ``'keep_ori_im': 1`` is added when
            ``self.keep_ori_im`` is truthy. When ``self.sample_num`` is
            positive, at most that many records are returned.

        Raises:
            AssertionError: if a listed frame is not a file, or if no frame
                was found (e.g. the frame directory is missing or empty).
        """
        if self.frame_rate == -1:
            # No frame rate given: read it from the video's metadata. Release
            # the capture handle right away so the file is not left open.
            # NOTE(review): if the video cannot be opened, OpenCV presumably
            # reports 0 FPS and frame_rate becomes 0 — confirm this is
            # acceptable downstream.
            cap = cv2.VideoCapture(self.video_file)
            try:
                self.frame_rate = int(cap.get(cv2.CAP_PROP_FPS))
            finally:
                cap.release()

        # Strip only the trailing extension. The previous
        # str.replace('.<ext>', '') removed every occurrence of that text in
        # the path (e.g. a parent directory named 'run.mp4'). For a file
        # without an extension it also took the "extension" from whatever
        # followed the last dot anywhere in the path.
        output_path = os.path.splitext(self.video_file)[0]
        # ffmpeg-based extraction (video2frames) is intentionally skipped;
        # the frames are expected to be in output_path already.
        self.video_frames = natsorted(
            glob.glob(os.path.join(output_path, '*.jpg')))

        self.video_length = len(self.video_frames)
        logger.info('Length of the video: {:d} frames.'.format(
            self.video_length))

        records = []
        for ct, image in enumerate(self.video_frames):
            assert image != '' and os.path.isfile(image), \
                    "Image {} not found".format(image)
            # Optional cap on the number of frames to process.
            if self.sample_num > 0 and ct >= self.sample_num:
                break
            rec = {'im_id': np.array([ct]), 'im_file': image}
            if self.keep_ori_im:
                rec.update({'keep_ori_im': 1})
            self._imid2path[ct] = image
            records.append(rec)
        assert len(records) > 0, "No image file found"
        return records

修改后的入口脚本（tools/infer_mot.py）：

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

# Add the PaddleDetection repository root to sys.path so the `ppdet` imports
# below resolve without installing the package. The root is the parent of
# the directory containing this file (assumes the script lives in
# <repo>/tools/, so <repo>/tools/infer_mot.py -> <repo>).
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)

import warnings

# Suppress all Python warnings for the rest of the run.
warnings.filterwarnings('ignore')

import paddle
from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config
from ppdet.engine import Tracker
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser

from ppdet.utils.logger import setup_logger

# Module-level logger created through ppdet's setup_logger, named 'train'.
# NOTE(review): `logger` is not referenced anywhere else in this script, and
# the name 'train' looks carried over from a training entry point — confirm
# it is still wanted. It is kept because setup_logger may configure logging
# as a side effect.
logger = setup_logger('train')

def str2bool(value):
    """Convert a command-line token such as "True", "false" or "1" to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, including "False". An option with no ``type`` keeps
    the raw string, which is likewise truthy. This converter parses the text
    explicitly instead. Values that are already ``bool`` are returned
    unchanged.

    Raises:
        ValueError: for unrecognised text; argparse reports it as an invalid
            value for the option.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError('expected a boolean value, got {!r}'.format(value))


def parse_args():
    """Parse the command-line flags for multi-object-tracking inference.

    The boolean options ``--save_images``, ``--save_videos``, ``--show_image``
    and ``--scaled`` are converted with ``str2bool``, so ``--scaled False``
    really yields ``False``. Previously any non-empty value passed for these
    options ended up truthy. Defaults are unchanged.

    Returns:
        The namespace returned by ``ArgsParser.parse_args()``.
        NOTE(review): ``ArgsParser`` (ppdet.utils.cli) presumably also adds
        the ``opt`` attribute that the ``__main__`` block passes to
        ``merge_config`` — confirm.
    """
    parser = ArgsParser()
    # The default config path is relative to the current working directory
    # (presumably the script is launched from <repo>/tools/).
    parser.add_argument(
        '--config', type=str,
        default="../configs/mot/fairmot/fairmot_dla34_30e_576x320.yml",
        help='Path to the tracking config file (YAML).')
    parser.add_argument(
        '--video_file', type=str, default="1.mp4",
        help='Video name for tracking.')
    parser.add_argument(
        '--frame_rate', type=int, default=-1,
        help='Video frame rate for tracking.')
    parser.add_argument(
        "--image_dir", type=str, default=None,
        help="Directory for images to perform inference on.")
    parser.add_argument(
        "--det_results_dir", type=str, default='',
        help="Directory name for detection results.")
    parser.add_argument(
        '--output_dir', type=str, default='output',
        help='Directory name for output tracking results.')
    parser.add_argument(
        '--save_images', type=str2bool, default=False,
        help='Save tracking results (image).')
    parser.add_argument(
        '--save_videos', type=str2bool, default=False,
        help='Save tracking results (video).')
    parser.add_argument(
        '--show_image', type=str2bool, default=True,
        help='Show tracking results (image).')
    parser.add_argument(
        '--scaled', type=str2bool, default=False,
        help="Whether coords after detector outputs are scaled, False in JDE "
             "YOLOv3, True in general detector.")
    parser.add_argument(
        "--draw_threshold", type=float, default=0.5,
        help="Threshold to reserve the result for visualization.")
    args = parser.parse_args()
    return args


def run(FLAGS, cfg):
    """Build a ``Tracker`` for ``cfg``, load its weights and run MOT inference.

    Args:
        FLAGS: the namespace returned by ``parse_args()``. Supplies the
            input (video file / image dir), output and visualisation options.
        cfg: the loaded PaddleDetection config. Reads ``architecture``,
            ``metric`` and ``weights``, plus ``det_weights`` and
            ``reid_weights`` for DeepSORT.
    """
    tracker = Tracker(cfg, mode='test')

    # DeepSORT loads detector + ReID weights through load_weights_sde; every
    # other architecture loads a single weights file via load_weights_jde.
    if cfg.architecture not in ['DeepSORT']:
        tracker.load_weights_jde(cfg.weights)
    else:
        # The config writes "no detector weights" as the string 'None'; map
        # it to a real None and pass any other value through unchanged.
        det_weights = None if cfg.det_weights == 'None' else cfg.det_weights
        tracker.load_weights_sde(det_weights, cfg.reid_weights)

    tracker.mot_predict(
        video_file=FLAGS.video_file,
        frame_rate=FLAGS.frame_rate,
        image_dir=FLAGS.image_dir,
        data_type=cfg.metric.lower(),
        model_type=cfg.architecture,
        output_dir=FLAGS.output_dir,
        save_images=FLAGS.save_images,
        save_videos=FLAGS.save_videos,
        show_image=FLAGS.show_image,
        scaled=FLAGS.scaled,
        det_results_dir=FLAGS.det_results_dir,
        draw_threshold=FLAGS.draw_threshold,
    )

if __name__ == '__main__':
    # Parse the CLI, load the YAML config, then merge FLAGS.opt into it via
    # merge_config (cfg is not reassigned, so the merge presumably updates
    # the loaded config in place — confirm in ppdet.core.workspace).
    FLAGS = parse_args()
    cfg = load_config(FLAGS.config)
    merge_config(FLAGS.opt)

    # ppdet's config / GPU / version checks, before any model is built.
    check_config(cfg)
    check_gpu(cfg.use_gpu)
    check_version()

    # Select GPU number ParallelEnv().dev_id when cfg.use_gpu is set,
    # otherwise the CPU.
    if cfg.use_gpu:
        device = 'gpu:{}'.format(ParallelEnv().dev_id)
    else:
        device = 'cpu'
    place = paddle.set_device(device)

    run(FLAGS, cfg)

你可能感兴趣的:(视觉跟踪,百度,行人跟踪,行人检测)