【MMDetection-学习记录】 为mmdetection\demo添加video_demo.py 进行视频检测并保存

【上一篇】【MMDetection-学习记录】 Windows10操作系统下安装并运行

 

【MMDetection-学习记录】 为mmdetection\demo添加video_demo.py 进行视频检测并保存_第1张图片

 

video_demo.py

'''
Description: Run an MMDetection detector on a video file and write the annotated frames to an output video.
version: 
Author: LiQiang
Date: 2021-01-21 11:45:22
LastEditTime: 2021-01-21 13:05:07
'''
import argparse

import cv2
import torch
import os

# Locate this script on disk; the default input video path points at
# test2.mp4 inside the same directory.
file_path = __file__  # path of the current file
dir_path = os.path.split(file_path)[0]
print(dir_path)
default_video_path = os.path.join(dir_path, 'test2.mp4')

from mmdet.apis import inference_detector, init_detector


def parse_args():
    """Parse the command-line arguments of the video detection demo.

    Positional arguments:
        config: test config file path.
        checkpoint: checkpoint file.

    Options:
        --device: CPU/CUDA device option (default ``cuda:0``).
        --file: test video path (default ``default_video_path``).
        --out: output video path (no default, so ``None`` when omitted).
        --score-thr: bbox score threshold (default ``0.5``).

    Returns:
        argparse.Namespace: the parsed arguments, available as ``config``,
        ``checkpoint``, ``device``, ``file``, ``out`` and ``score_thr``.
    """
    parser = argparse.ArgumentParser(description='MMDetection video demo')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--device', type=str, default='cuda:0', help='CPU/CUDA device option')
    parser.add_argument(
        '--file', type=str, default=default_video_path,
        help='test video path')
    parser.add_argument(
        '--out', type=str, help='output video path')
    parser.add_argument(
        '--score-thr', type=float, default=0.5, help='bbox score threshold')
    args = parser.parse_args()
    return args


def main():
    """Run the detector on each frame of a video, preview it and save it.

    Reads the video given by ``--file``, runs ``inference_detector`` on every
    frame, draws the detections with ``model.show_result`` and shows the
    annotated frame in a window named ``frame``. When ``--out`` is given the
    annotated frames are also appended to that output video; without it the
    results are only displayed.

    Processing stops when no further frame can be read, or when Esc, ``q``
    or ``Q`` is pressed in the preview window. The capture, the writer and
    the preview window are released even if an error interrupts the loop.
    """
    args = parse_args()

    if not args.file:
        print("no input test file")
        exit(0)

    device = torch.device(args.device)

    model = init_detector(args.config, args.checkpoint, device=device)

    cap = cv2.VideoCapture(args.file)
    video_writer = None
    try:
        if not cap.isOpened():
            print('Cannot open input video: {}'.format(args.file))
            return

        if args.out:
            # Reuse the input's frame size and frame rate for the output.
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps_video = cap.get(cv2.CAP_PROP_FPS)
            # Codec used for the written video.
            fourcc = cv2.VideoWriter_fourcc(*"mp4v")
            video_writer = cv2.VideoWriter(
                args.out, fourcc, fps_video, (frame_width, frame_height))
            if not video_writer.isOpened():
                # Without this check every write() would silently do nothing.
                print('Cannot open output video: {}'.format(args.out))
                return
        else:
            print('No --out given; results are shown but not saved.')

        count = 0  # number of frames written to the output video
        print('Press "Esc", "q" or "Q" to exit.')
        while True:
            # Kept from the original loop: drop cached CUDA memory per frame.
            torch.cuda.empty_cache()
            ret_val, img = cap.read()
            if not ret_val:
                # No frame was returned: end of the video or a read error.
                print('fail!!')
                break

            result = inference_detector(model, img)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord('q') or ch == ord('Q'):
                break

            frame = model.show_result(
                img, result, score_thr=args.score_thr, wait_time=1,
                show=False, thickness=1, font_scale=1)
            # Explicit None/empty test; the truth value of an array is
            # ambiguous and must not be used as the check.
            if frame is None or len(frame) < 1:
                continue

            cv2.imshow('frame', frame)
            if video_writer is not None:
                # Append the annotated frame to the output video.
                video_writer.write(frame)
                count += 1
                print('Write {} in result Successfully!'.format(count))
    finally:
        cap.release()
        if video_writer is not None:
            video_writer.release()
        cv2.destroyAllWindows()

# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

运行命令(--file 指定输入视频文件,--out 指定输出视频文件):

python demo\video_demo.py configs\faster_rcnn\faster_rcnn_r50_fpn_1x_coco.py checkpoints\faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth --file demo\test2.mp4 --out demo\1.mp4

 

 

参考:https://www.bilibili.com/video/BV1jV411U7zb?p=3

 

你可能感兴趣的:(pyTorch,笔记,mmdetection)