python opencv自动跟踪移动目标 (DaSiamRPN)

1、原代码有几个bug,要特别注意cv2.findContours函数的返回值:OpenCV 3.x返回(image, contours, hierarchy)三个值,而OpenCV 2.x/4.x只返回(contours, hierarchy)两个值。

 

2、程序调用了DaSiamRPN跟踪网络,跟踪效果不错,fps也很高

 

3、代码如下

'''
读取视频、检测前景目标,调用DaSiamRPN进行跟踪
'''
import cv2
import torch
import numpy as np
from os.path import realpath, dirname, join
from net import SiamRPNvot
from run_SiamRPN import SiamRPN_init, SiamRPN_track
from utils import get_axis_aligned_bbox, cxy_wh_2_rect



# Build the DaSiamRPN network (VOT variant) once, at import time, and load the
# pretrained weights from 'SiamRPNVOT.model' in the same directory as this
# script. The network is switched to inference mode and moved to the GPU.
net = SiamRPNvot()
net.load_state_dict(
    torch.load(join(realpath(dirname(__file__)), 'SiamRPNVOT.model'))
)
net.eval()
net.cuda()


def _largest_foreground_box(fgmask, min_area=400):
    """Return the bounding box ``(x, y, w, h)`` of the largest blob in *fgmask*.

    Returns None when the mask contains no contours, or when the bounding box
    of the largest contour covers no more than *min_area* pixels (such small
    blobs are treated as noise).
    """
    # cv2.findContours returns (image, contours, hierarchy) in OpenCV 3.x but
    # (contours, hierarchy) in 2.x/4.x; the contours are always at index -2.
    contours = cv2.findContours(fgmask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        return None
    # Rank by enclosed area rather than contour.size: with CHAIN_APPROX_SIMPLE
    # the number of stored points says little about how large a blob is.
    biggest = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(biggest)
    if w * h <= min_area:
        return None
    return x, y, w, h


def videoTrack(source=0):
    """Detect a moving object in a video stream and follow it with DaSiamRPN.

    Foreground pixels come from a MOG2 background subtractor. When the user
    presses S, the bounding box of the largest foreground blob initialises the
    DaSiamRPN tracker (as soon as a blob whose box exceeds 400 px appears).
    From then on the tracker is updated on every frame, and its box is drawn
    in the "track" window.

    Keys (in the "track" window):
        S / s       start tracking the largest moving object
        R / r       drop the current track and acquire a new target
        P / p       stop tracking
        Q / q / Esc quit

    Args:
        source: anything accepted by ``cv2.VideoCapture``: a camera index or a
            video file path. Defaults to 0 (the first camera), as before.
    """
    cap = cv2.VideoCapture(source)
    fgbg = cv2.createBackgroundSubtractorMOG2()
    # Used to erode isolated noise pixels out of the foreground mask. It never
    # changes, so build it once instead of on every frame.
    element = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))

    startTrack = False
    restartTrack = False
    stopTrack = False
    isTracking = False
    state = None

    try:
        ret, frame = cap.read()
        while ret:
            # Foreground detection, then erosion to suppress speckle noise.
            fgmask = fgbg.apply(frame)
            fgmask = cv2.erode(fgmask, element)

            if startTrack:
                box = _largest_foreground_box(fgmask, min_area=400)
                if box is not None:
                    x, y, w, h = box
                    # SiamRPN_init takes the target *centre* (cx, cy), not the
                    # top-left corner; cxy_wh_2_rect below converts the centre
                    # back into a rect, so passing (x, y) would offset the box.
                    target_pos = np.array([x + w / 2, y + h / 2])
                    target_sz = np.array([w, h])
                    state = SiamRPN_init(frame, target_pos, target_sz, net)
                    isTracking = True
                    startTrack = False

            if isTracking:
                state = SiamRPN_track(state, frame)
                res = cxy_wh_2_rect(state['target_pos'], state['target_sz'])
                res = [int(v) for v in res]
                cv2.rectangle(frame, (res[0], res[1]),
                              (res[0] + res[2], res[1] + res[3]), (0, 255, 255), 3)

            # Re-track: drop the active track on this frame, then arm
            # startTrack on the following frame.
            if restartTrack:
                if isTracking:
                    isTracking = False
                else:
                    startTrack = True
                    restartTrack = False

            if stopTrack:
                isTracking = False
                stopTrack = False

            cv2.imshow("track", frame)

            # Keep only the low byte: some HighGUI backends report modifier
            # bits in the upper bits of the key code.
            key = cv2.waitKey(10) & 0xFF
            if key in (ord('S'), ord('s')):  # S: start tracking
                print("------------开始跟踪-----------------")
                startTrack = True
            elif key in (ord('R'), ord('r')):  # R: re-track
                print("------------重新跟踪-----------------")
                restartTrack = True
            elif key in (ord('P'), ord('p')):  # P: stop tracking
                print("------------停止跟踪-----------------")
                stopTrack = True
            elif key in (ord('Q'), ord('q'), 27):  # Q / Esc: quit
                break

            ret, frame = cap.read()
    finally:
        # Always free the capture device and close the HighGUI windows.
        cap.release()
        cv2.destroyAllWindows()










if __name__ == "__main__":
    # Launch the interactive tracker only when run as a script, not on import.
    videoTrack()

 

你可能感兴趣的:(opencv)