Pytorch入门-使用torchvision自带的faster-rcnn进行目标检测

# -*- coding: utf-8 -*-
"""
Created on Thu Jul 30 08:47:12 2020

@author: Johnson
"""
import cv2 as cv
import  torchvision
import torch
from torchvision import transforms
import numpy as np


# Load the class names, one per line, stripping surrounding whitespace.
with open("coco.names") as names_file:
    coco_names = [entry.strip() for entry in names_file]


# torchvision can load the pretrained detection model directly.
# NOTE(review): recent torchvision releases deprecate `pretrained=True` in
# favour of a `weights=` argument — confirm against the installed version.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()  # switch the model to evaluation (inference) mode

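'''
A quick overview of transforms: transforms is PyTorch's image preprocessing
package; Compose is normally used to chain several steps together.
    Resize: resize the given image to the given size;
    Normalize: normalize a tensor image with mean and standard deviation;
    ToTensor: convert a PIL image or numpy.ndarray (H*W*C) in the range [0,255]
              to a torch.Tensor (C*H*W) in the range [0.0,1.0];
    ToPILImage: convert a tensor to a PIL image;
    Scale: no longer used, Resize is recommended instead;
    CenterCrop: crop the central region of the image;
    RandomCrop: crop at a random location;
    RandomHorizontalFlip: horizontally flip the given PIL image with probability 0.5;
    RandomVerticalFlip: vertically flip the given PIL image with probability 0.5;
    RandomResizedCrop: crop the PIL image to a random size and aspect ratio;
    Grayscale: convert the image to grayscale;
    RandomGrayscale: convert the image to grayscale with a given probability;
    FiveCrop: crop the image into the four corners and the center;
    TenCrop: the five crops above plus their flipped versions;
    Pad: pad the image;
    ColorJitter: randomly change the brightness, contrast and saturation of the image
'''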
# tf = transforms.Compose([
#             transforms.Resize(256),
#             transforms.CenterCrop(224),
#             transforms.ToTensor(),
#             transforms.Normalize(
#             mean=[0.485, 0.456, 0.406],
#             std=[0.229, 0.224, 0.225]
#         )])

# Preprocessing pipeline applied to images/frames before inference: only ToTensor.
transform = transforms.Compose([transforms.ToTensor()])

def tf(image):
    """Convert an image array into a normalized, batched tensor.

    The image is resized to 224x224, converted to float32 and divided by
    255, then a per-channel mean is subtracted and the result is divided by
    a per-channel standard deviation. Finally the axes are reordered from
    (H, W, C) to (C, H, W) and a leading batch axis is added.

    NOTE(review): the mean/std constants are the same values as in the
    commented-out ``transforms.Normalize`` call in this file; if ``image``
    comes from ``cv.imread`` its channels are in BGR order — confirm that
    the constant order matches the intended channel order.

    Parameters
    ----------
    image : numpy.ndarray
        Input image with three channels on the last axis.

    Returns
    -------
    torch.Tensor
        Tensor with a leading batch axis of size 1.
    """
    channel_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    channel_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    resized = cv.resize(image, (224, 224))
    scaled = np.float32(resized) / 255.0
    # Normalize in place, broadcasting over the trailing channel axis.
    scaled -= channel_mean
    scaled /= channel_std

    channels_first = scaled.transpose((2, 0, 1))
    return torch.from_numpy(channels_first).unsqueeze(0)


# Move the model to the GPU when CUDA is available.
if torch.cuda.is_available():
    model.cuda()
    
#主函数
def faster_rcnn_detection(path):
    """Detect objects in a single image with Faster R-CNN and display them.

    The image is read from disk, every detection scoring above 0.5 is drawn
    on it (box plus class label), and the annotated image is shown in a
    window until a key is pressed.

    Parameters
    ----------
    path : str
        Path of the input image.

    Returns
    -------
    None

    Raises
    ------
    OSError
        If the image cannot be read (``cv.imread`` returned ``None``).
    """
    image = cv.imread(path)
    if image is None:
        # cv.imread signals a missing/undecodable file by returning None.
        raise OSError("could not read image: {!r}".format(path))

    _detect_and_draw(image)
    cv.imshow("Faster-RCNN Detection Demo", image)
    cv.waitKey(0)
    cv.destroyAllWindows()


def _detect_and_draw(image, score_threshold=0.5):
    """Run the detector on one BGR image and draw confident detections in place.

    Parameters
    ----------
    image : numpy.ndarray
        BGR image as produced by OpenCV; it is annotated in place.
    score_threshold : float, optional
        Only detections scoring strictly above this value are drawn.

    Returns
    -------
    numpy.ndarray
        The same ``image`` object, with boxes and labels drawn on it.
    """
    # OpenCV decodes to BGR while the torchvision detection models expect RGB.
    rgb = cv.cvtColor(image, cv.COLOR_BGR2RGB)
    batch = transform(rgb).unsqueeze(0)

    # The model may have been moved to the GPU; the input must live on the
    # same device or the forward pass fails with a device-mismatch error.
    device = next(model.parameters()).device
    with torch.no_grad():  # inference only: do not build an autograd graph
        output = model(batch.to(device))[0]

    boxes = output['boxes'].cpu().numpy()
    scores = output['scores'].cpu().numpy()
    labels = output['labels'].cpu().numpy()

    for (x1, y1, x2, y2), score, label_id in zip(boxes, scores, labels):
        if score > score_threshold:
            print("boxes info", x1, y1, x2, y2)
            top_left = (int(x1), int(y1))
            bottom_right = (int(x2), int(y2))
            cv.rectangle(image, top_left, bottom_right, (0, 255, 255), 1, 8, 0)

            label_id = int(label_id)
            # Fall back to the numeric id when the names file has no entry
            # for this label instead of raising IndexError.
            # NOTE(review): the model's label ids index straight into
            # coco_names — confirm the names file uses the same numbering.
            if 0 <= label_id < len(coco_names):
                label_txt = coco_names[label_id]
            else:
                label_txt = str(label_id)
            cv.putText(image, label_txt, top_left, cv.FONT_HERSHEY_PLAIN,
                       1.0, (0, 0, 255), 1)
    return image


def video_detection(path):
    """Run Faster R-CNN detection on a video file or camera stream.

    Frames are read until the stream ends or Esc is pressed; each frame is
    annotated with detections scoring above 0.5 and shown in a window. The
    capture device and the window are released on exit, including on error.

    Parameters
    ----------
    path : int or str
        0 (camera) or the path of a video file.

    Returns
    -------
    None
    """
    capture = cv.VideoCapture(path)
    try:
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            _detect_and_draw(frame)
            cv.imshow("video detection", frame)
            # 27 is the Esc key; mask to the low byte because some platforms
            # return extra modifier bits from waitKey.
            if cv.waitKey(1) & 0xFF == 27:
                break
    finally:
        capture.release()
        cv.destroyAllWindows()
        
        
    
# Alternative entry point: run detection on a single image file.
# faster_rcnn_detection("D:/images_GRPC/faces_proto/05.jpg")
# Executed whenever this file runs: start detection on video source 0.
video_detection(0)




你可能感兴趣的:(目标检测,代码脚本)