Single/Multi-Camera Pedestrian Tracking with Baidu PaddlePaddle (Unofficial Baseline)

Single-camera pedestrian detection and object tracking with PaddlePaddle

  • Reproduction of this article is prohibited!
    • 1. Install PaddleX
    • 2. Extract the Dataset
    • 3. Set the Working Directory
    • 4. Generate the Dataset TXT Files
    • 5. Define the Data Preprocessing Pipeline
    • 6. Define the Training and Test Sets
    • 7. Define and Train the Model
    • 8. Evaluate the Model
    • 9. Test the Detection Results
  • 10. Add Object Tracking
  • 11. Generate the MOT20 Submission File
  • My WeChat Official Account:

Reproduction of this article is prohibited!

B站:https://space.bilibili.com/470550823

CSDN:https://blog.csdn.net/weixin_44936889

AI Studio:https://aistudio.baidu.com/aistudio/personalcenter/thirdview/67156

Github:https://github.com/Sharpiless

1. Install PaddleX

!pip install paddlex -i https://mirror.baidu.com/pypi/simple
done
  Created wheel for pycocotools: filename=pycocotools-2.0.2-cp37-cp37m-linux_x86_64.whl size=278364 sha256=04aff7ea999743d986d883311e2adcf537537759f36adb7c25b4c8b713de36aa
  Stored in directory: /home/aistudio/.cache/pip/wheels/fb/44/67/8baa69040569b1edbd7776ec6f82c387663e724908aaa60963
Successfully built pycocotools
Installing collected packages: shapely, pycocotools, paddleslim, xlwt, paddlex
Successfully installed paddleslim-1.1.1 paddlex-1.3.7 pycocotools-2.0.2 shapely-1.7.1 xlwt-1.3.0
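To confirm the installation, a quick sanity check (a minimal sketch; both packages expose a __version__ attribute):

import paddle
import paddlex
print('paddle:', paddle.__version__)
print('paddlex:', paddlex.__version__)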

2. Extract the Dataset

!unzip /home/aistudio/data/data4379/pascalvoc.zip -d /home/aistudio/work/
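If the archive extracted cleanly, the VOC layout used in the following steps should now be in place (a quick check; JPEGImages and Annotations are the folders read later):

!ls /home/aistudio/work/pascalvoc/VOCdevkit/VOC2012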

3. Set the Working Directory

import matplotlib
matplotlib.use('Agg')  # headless backend, suitable for AI Studio
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # train on GPU 0
import paddlex as pdx

os.chdir('/home/aistudio/work/')

4. Generate the Dataset TXT Files

PaddleX supports VOC-format data. The training and test sets each need a TXT file listing, one pair per line, an image path and its annotation file path, in the following format:

JPEGImages/2009_003143.jpg Annotations/2009_003143.xml

JPEGImages/2012_001604.jpg Annotations/2012_001604.xml

from random import shuffle, seed

base = '/home/aistudio/work/pascalvoc/VOCdevkit/VOC2012/'

imgs = os.listdir(os.path.join(base, 'JPEGImages'))
print('total:', len(imgs))

seed(666)  # fixed seed so the split is reproducible
shuffle(imgs)

with open(os.path.join(base, 'train_list.txt'), 'w') as f:
    for im in imgs[:5000]:  # first 5000 shuffled images form the training set
        info = 'JPEGImages/'+im+' '
        info += 'Annotations/'+im[:-4]+'.xml\n'
        f.write(info)

with open(os.path.join(base, 'val_list.txt'), 'w') as f:
    for im in imgs[-1000:]:  # last 1000 shuffled images form the test set
        info = 'JPEGImages/'+im+' '
        info += 'Annotations/'+im[:-4]+'.xml\n'
        f.write(info)

CLASSES = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
           'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
           'train', 'tvmonitor']

with open('labels.txt', 'w') as f:
    for v in CLASSES:
        f.write(v+'\n')
total: 17125

5. Define the Data Preprocessing Pipeline

The data augmentation here includes image mixup, random pixel distortion, random expansion, random cropping, and random horizontal flipping.

from paddlex.det import transforms
train_transforms = transforms.Compose([
    transforms.MixupImage(mixup_epoch=250),
    transforms.RandomDistort(),
    transforms.RandomExpand(),
    transforms.RandomCrop(),
    transforms.Resize(target_size=512, interp='RANDOM'),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(),
])

eval_transforms = transforms.Compose([
    transforms.Resize(target_size=512, interp='CUBIC'),
    transforms.Normalize(),
])

6. Define the Training and Test Sets

Here, 5,000 images are taken as the training set and 1,000 images as the test set.

base = '/home/aistudio/work/pascalvoc/VOCdevkit/VOC2012/'

train_dataset = pdx.datasets.VOCDetection(
    data_dir=base,
    file_list=os.path.join(base, 'train_list.txt'),
    label_list='labels.txt',
    transforms=train_transforms,
    shuffle=True)
eval_dataset = pdx.datasets.VOCDetection(
    data_dir=base,
    file_list=os.path.join(base, 'val_list.txt'),
    label_list='labels.txt',
    transforms=eval_transforms)
2020-12-28 10:44:17 [INFO]	Starting to read file list from dataset...
2020-12-28 10:44:25 [INFO]	5000 samples in file /home/aistudio/work/pascalvoc/VOCdevkit/VOC2012/train_list.txt
creating index...
index created!
2020-12-28 10:44:25 [INFO]	Starting to read file list from dataset...
2020-12-28 10:44:27 [INFO]	1000 samples in file /home/aistudio/work/pascalvoc/VOCdevkit/VOC2012/val_list.txt
creating index...
index created!

7. Define and Train the Model

Here we define a YOLOv3 model, using MobileNetV3_large as the backbone network:

num_classes = len(train_dataset.labels)  # PaddleX's YOLOv3 does not expect an extra background class
print('class num:', num_classes)
model = pdx.det.YOLOv3(
    num_classes=num_classes, 
    backbone='MobileNetV3_large'
)
model.train(
    num_epochs=60,
    train_dataset=train_dataset,
    train_batch_size=4,
    eval_dataset=eval_dataset,
    learning_rate=0.00025,
    lr_decay_epochs=[20, 40],
    save_interval_epochs=4,
    log_interval_steps=100,
    save_dir='./YOLOv3',
    use_vdl=True)
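With use_vdl=True, PaddleX writes VisualDL logs under the save directory, so training curves can be inspected in the browser. A sketch of launching the viewer, assuming the default vdl_log subdirectory under save_dir:

!visualdl --logdir ./YOLOv3/vdl_log --port 8001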

8. Evaluate the Model

model = pdx.load_model('./YOLOv3/best_model')
model.evaluate(eval_dataset, batch_size=1, epoch_id=None, metric=None, return_details=False)
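evaluate() also returns the metrics, so they can be captured and printed. A sketch, assuming the default VOC-style metric (the exact dict keys may vary across PaddleX versions):

metrics = model.evaluate(eval_dataset, batch_size=1)
print(metrics)  # e.g. a dict containing the bbox mAP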

9. Test the Detection Results

import cv2
import time
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

image_name = './test.jpg'
start = time.time()
result = model.predict(image_name, eval_transforms)
print('infer time:{:.6f}s'.format(time.time()-start))
print('detected num:', len(result))

im = cv2.imread(image_name)
font = cv2.FONT_HERSHEY_SIMPLEX
threshold = 0.01

for value in result:
    xmin, ymin, w, h = np.array(value['bbox']).astype(int)  # PaddleX bbox is [x, y, w, h]
    cls = value['category']
    score = value['score']
    if score < threshold:
        continue
    cv2.rectangle(im, (xmin, ymin), (xmin+w, ymin+h), (0, 255, 0), 4)
    cv2.putText(im, '{:s} {:.3f}'.format(cls, score),
                    (xmin, ymin), font, 0.5, (255, 0, 0), thickness=2)

cv2.imwrite('result.jpg', im)
plt.figure(figsize=(15,12))
plt.imshow(im[:, :, [2,1,0]])
plt.show()
infer time:0.024358s
detected num: 5

[Figure 1: YOLOv3 detection results on test.jpg]

10. Add Object Tracking

!pip install dlib
Looking in indexes: https://mirror.baidu.com/pypi/simple/
Collecting dlib
  Downloading https://mirror.baidu.com/pypi/packages/99/2c/ef681c1c717ace11040f9e99fe22dafc843cdd6085eb6120e7ab2a5c662b/dlib-19.21.1.tar.gz (3.6MB)
Building wheels for collected packages: dlib
  Building wheel for dlib (setup.py) ...
%cd work/
/home/aistudio/work
import dlib
import cv2


def plot_bboxes(image, bboxes, line_thickness=None):
    # Plot bounding boxes with class labels and track IDs on the image
    tl = line_thickness or round(
        0.002 * (image.shape[0] + image.shape[1]) / 2) + 1  # line/font thickness
    for (x1, y1, x2, y2, cls_id, pos_id) in bboxes:
        # draw these classes in red, everything else in green
        if cls_id in ['smoke', 'phone', 'eat']:
            color = (0, 0, 255)
        else:
            color = (0, 255, 0)
        if cls_id == 'eat':
            cls_id = 'eat-drink'
        c1, c2 = (x1, y1), (x2, y2)
        cv2.rectangle(image, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(cls_id, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(image, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(image, '{} ID-{}'.format(cls_id, pos_id), (c1[0], c1[1] - 2), 0, tl / 3,
                    [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)

    return image


def update_tracker(target_detector, image):

    raw = image.copy()

    if target_detector.frameCounter > 2e+4:
        target_detector.frameCounter = 0

    faceIDtoDelete = []

    for faceID in target_detector.faceTracker.keys():
        trackingQuality = target_detector.faceTracker[faceID].update(image)

        if trackingQuality < 8:  # update() returns a confidence score; drop low-quality tracks
            faceIDtoDelete.append(faceID)

    for faceID in faceIDtoDelete:
        target_detector.faceTracker.pop(faceID, None)
        target_detector.faceLocation1.pop(faceID, None)
        target_detector.faceLocation2.pop(faceID, None)
        target_detector.faceClasses.pop(faceID, None)

    new_faces = []

    if not (target_detector.frameCounter % target_detector.stride):  # run the detector every 'stride' frames

        _, bboxes = target_detector.detect(image)

        for (x1, y1, x2, y2, cls_id, _) in bboxes:
            x = int(x1)
            y = int(y1)
            w = int(x2-x1)
            h = int(y2-y1)

            x_bar = x + 0.5 * w
            y_bar = y + 0.5 * h

            matchCarID = None

            for faceID in target_detector.faceTracker.keys():
                trackedPosition = target_detector.faceTracker[faceID].get_position(
                )

                t_x = int(trackedPosition.left())
                t_y = int(trackedPosition.top())
                t_w = int(trackedPosition.width())
                t_h = int(trackedPosition.height())

                t_x_bar = t_x + 0.5 * t_w
                t_y_bar = t_y + 0.5 * t_h

                # match when each box's center lies inside the other box
                if t_x <= x_bar <= (t_x + t_w) and t_y <= y_bar <= (t_y + t_h):
                    if x <= t_x_bar <= (x + w) and y <= t_y_bar <= (y + h):
                        matchCarID = faceID

            if matchCarID is None:
                # newly appeared target: start a dlib correlation tracker on its box
                tracker = dlib.correlation_tracker()
                tracker.start_track(
                    image, dlib.rectangle(x, y, x + w, y + h))

                target_detector.faceTracker[target_detector.currentCarID] = tracker
                target_detector.faceLocation1[target_detector.currentCarID] = [
                    x, y, w, h]

                matchCarID = target_detector.currentCarID
                target_detector.currentCarID = target_detector.currentCarID + 1

                if cls_id == 'face':  # only triggers if the detector emits a 'face' class (not among the VOC labels)
                    pad_x = int(w * 0.15)
                    pad_y = int(h * 0.15)
                    if x > pad_x:
                        x = x-pad_x
                    if y > pad_y:
                        y = y-pad_y
                    face = raw[y:y+h+pad_y*2, x:x+w+pad_x*2]
                    new_faces.append((face, matchCarID))

                target_detector.faceClasses[matchCarID] = cls_id

    bboxes2draw = []
    for faceID in target_detector.faceTracker.keys():
        trackedPosition = target_detector.faceTracker[faceID].get_position()

        t_x = int(trackedPosition.left())
        t_y = int(trackedPosition.top())
        t_w = int(trackedPosition.width())
        t_h = int(trackedPosition.height())
        cls_id = target_detector.faceClasses[faceID]
        target_detector.faceLocation2[faceID] = [t_x, t_y, t_w, t_h]
        bboxes2draw.append(
            (t_x, t_y, t_x+t_w, t_y+t_h, cls_id, faceID)
        )

    image = plot_bboxes(image, bboxes2draw)

    target_detector.frameCounter += 1  # advance the counter so the detection stride takes effect

    return image, bboxes2draw
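update_tracker relies on three dlib calls: start_track() to initialize a correlation tracker on a box, update() to follow it into the next frame and return a confidence score (compared against 8 above), and get_position() to read back the tracked rectangle. A minimal standalone sketch; the frame file names here are placeholders:

import dlib
import cv2

frame0 = cv2.imread('frame_000.jpg')  # placeholder paths for two consecutive frames
frame1 = cv2.imread('frame_001.jpg')

tracker = dlib.correlation_tracker()
tracker.start_track(frame0, dlib.rectangle(50, 60, 150, 220))  # left, top, right, bottom

confidence = tracker.update(frame1)  # low values indicate the target was lost
pos = tracker.get_position()
print(confidence, int(pos.left()), int(pos.top()), int(pos.width()), int(pos.height()))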
from os import walk
import cv2
import numpy as np  # needed by baseDet.detect below
import paddlex as pdx

class baseDet(object):

    def __init__(self):

        self.img_size = 640  # input image size
        self.threshold = 0.01  # detection score threshold
        self.stride = 2  # detection stride (run the detector every N frames)
        self.model = pdx.load_model('./YOLOv3/best_model')
        self.build_config()

    def build_config(self):
        # initialize the state needed for tracking
        self.faceTracker = {}  # track ID -> dlib correlation tracker
        self.faceClasses = {}  # track ID -> class label
        self.faceLocation1 = {}  # track ID -> initial [x, y, w, h]
        self.faceLocation2 = {}  # track ID -> latest [x, y, w, h]
        self.frameCounter = 0
        self.currentCarID = 0
        self.walk_dict = {}
        self.recorded = []

        self.font = cv2.FONT_HERSHEY_SIMPLEX

    def feedCap(self, im):

        im, bboxes = update_tracker(self, im)

        return im, bboxes  # return the annotated frame and the tracked boxes

    def detect(self, im):
        result = self.model.predict(im)
        pred_boxes = []
        for value in result:
            x1, y1, w, h = np.array(value['bbox']).astype(int)
            cls = value['category']
            score = value['score']
            if score > self.threshold:
                pred_boxes.append(
                    (x1, y1, x1+w, y1+h, cls, score)
                )
        return im, pred_boxes
DET = baseDet()
2021-03-11 21:26:45 [INFO]	Model[YOLOv3] loaded.
import matplotlib.pyplot as plt

%matplotlib inline

im = cv2.imread('./test.jpg')

plt.imshow(im[:, :, [2,1,0]])
plt.show()

[Figure 2: the original test.jpg input frame]

import numpy as np
res_im, bboxes = DET.feedCap(im)
plt.imshow(res_im[:, :, [2,1,0]])
plt.show()

[Figure 3: tracking output with class labels and track IDs]

for k, v in DET.faceLocation2.items():
    print(k, v)
0 [18, 0, 190, 182]
1 [177, 82, 32, 82]
2 [198, 91, 27, 64]
3 [120, 109, 27, 73]

11. Generate the MOT20 Submission File

import os
from tqdm import tqdm

class VideoCapture(object):
    # mimic the cv2.VideoCapture interface over a MOT20 image sequence

    def __init__(self, img_path):
        self.name = img_path
        self.base = '../MOT20/images/test/{}/img1'
        self.img_path = self.base.format(img_path)
        self.num = len(os.listdir(self.img_path))  # number of frames in the sequence
        self.count = 0

    def read(self):
        self.count += 1
        img = os.path.join(self.img_path, '{:06}.jpg'.format(self.count))
        image = cv2.imread(img)
        return image is not None, image

cap = VideoCapture('MOT20-04')
font = cv2.FONT_HERSHEY_SIMPLEX

results = []  # accumulate MOT rows across all frames

for fid in tqdm(range(cap.num)):
    success, frame = cap.read()
    if not success:
        break
    res_im, bboxes = DET.feedCap(frame)
    for id_, output in DET.faceLocation2.items():
        x1, y1 = output[0], output[1]
        w, h = output[2], output[3]
        conf_ = 1.0
        # <frame>,<id>,<bb_left>,<bb_top>,<bb_width>,<bb_height>,<conf>,<x>,<y>,<z>
        results.append([fid + 1, id_, x1, y1, w,
                        h, conf_, -1, -1, -1])  # MOT frame indices are 1-based

with open(cap.name + '.txt', 'w') as f:
    for box in results:
        f.write(','.join(str(v) for v in box) + '\n')
 20%|██        | 1/5 [00:00<00:00,  6.03it/s]
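As a quick sanity check (a sketch), print the first few rows of the submission file to confirm the <frame>,<id>,<bb_left>,... layout:

with open(cap.name + '.txt') as f:
    for _, line in zip(range(3), f):  # show the first three rows
        print(line.strip())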

My WeChat Official Account:

[Figure 4: WeChat Official Account QR code]
