PyTorch Deep Learning Framework YOLOv3 Object Detection Learning Notes (Part 5): Building the Input/Output Pipeline

In this part we build the input and output pipeline for the detector: reading images from disk, running the network to get predictions, drawing bounding boxes on the images from those predictions, and saving the results back to disk. We will also see how to run detection in real time on a webcam feed.
We need OpenCV 3 installed.
Create the detector file detector.py in the project directory and add the following imports at the top:

from __future__ import division
import time
import torch 
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import cv2 
from util import *
import argparse
import os 
import os.path as osp
from darknet import Darknet
import pickle as pkl
import pandas as pd
import random

Creating the Command-Line Arguments

Since detector.py is the file we execute to run detection, it should accept its options from the command line. We use Python's argparse module for this.

def arg_parse():
    """
    Parse arguments to the detect module
    
    """
    
    parser = argparse.ArgumentParser(description='YOLO v3 Detection Module')
   
    parser.add_argument("--images", dest = 'images', help = 
                        "Image / Directory containing images to perform detection upon",
                        default = "imgs", type = str)
    parser.add_argument("--det", dest = 'det', help = 
                        "Image / Directory to store detections to",
                        default = "det", type = str)
    parser.add_argument("--bs", dest = "bs", help = "Batch size", default = 1)
    parser.add_argument("--confidence", dest = "confidence", help = "Object Confidence to filter predictions", default = 0.5)
    parser.add_argument("--nms_thresh", dest = "nms_thresh", help = "NMS Threshhold", default = 0.4)
    parser.add_argument("--cfg", dest = 'cfgfile', help = 
                        "Config file",
                        default = "cfg/yolov3.cfg", type = str)
    parser.add_argument("--weights", dest = 'weightsfile', help = 
                        "weightsfile",
                        default = "yolov3.weights", type = str)
    parser.add_argument("--reso", dest = 'reso', help = 
                        "Input resolution of the network. Increase to increase accuracy. Decrease to increase speed",
                        default = "416", type = str)
    
    return parser.parse_args()
    
args = arg_parse()
images = args.images
batch_size = int(args.bs)
confidence = float(args.confidence)
nms_thesh = float(args.nms_thresh)
start = 0
CUDA = torch.cuda.is_available()
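
All of these defaults can be overridden when launching the script; a typical invocation (the folder names here are only placeholders) might be:

python detector.py --images imgs --det det --bs 2 --confidence 0.6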

Loading the Network

Download the COCO class-names file (coco.names) and create a data folder in the project directory; on Linux you can do this from the terminal:

mkdir data
cd data
wget https://raw.githubusercontent.com/ayooshkathuria/YOLO_v3_tutorial_from_scratch/master/data/coco.names

Then load the class file in the program:

num_classes = 80    #For COCO
classes = load_classes("data/coco.names")

load_classes is a function defined in util.py; it returns a list that maps the index of every class to its name string, so predictions can be printed with readable labels.

def load_classes(namesfile):
    fp = open(namesfile, "r")
    names = fp.read().split("\n")[:-1]
    return names
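
As a quick sanity check (the exact strings come from coco.names), the loaded classes list should look like this:

print(len(classes))    # 80
print(classes[0])      # 'person'
print(classes[-1])     # 'toothbrush'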

Initializing the Network and Loading the Weights

#Set up the neural network
print("Loading network.....")
model = Darknet(args.cfgfile)
model.load_weights(args.weightsfile)
print("Network successfully loaded")

model.net_info["height"] = args.reso
inp_dim = int(model.net_info["height"])
assert inp_dim % 32 == 0 
assert inp_dim > 32

#If there's a GPU availible, put the model on GPU
if CUDA:
    model.cuda()

#Set the model in evaluation mode
model.eval()

Reading the Input Images

Read the images from disk, either a single file or a directory of images. The paths of the images are stored in the list imlist.

read_dir = time.time()
#Detection phase
try:
    imlist = [osp.join(osp.realpath('.'), images, img) for img in os.listdir(images)]
except NotADirectoryError:
    imlist = []
    imlist.append(osp.join(osp.realpath('.'), images))
except FileNotFoundError:
    print ("No file or directory with the name {}".format(images))
    exit()

read_dir is a checkpoint used to measure execution time.
If the directory det that will hold the detection results does not exist, create it:

if not os.path.exists(args.det):
    os.makedirs(args.det)

Load the images with OpenCV:


load_batch = time.time()
loaded_ims = [cv2.imread(x) for x in imlist]

load_batch is another timing checkpoint.
OpenCV loads images as numpy arrays in BGR channel order, while PyTorch expects input of shape batches × channels × height × width with RGB channels. So we write a prep_image function in util.py to convert the numpy array into PyTorch's input format.
Before writing that function, we need a letterbox_image function that resizes the image while keeping its aspect ratio unchanged, padding the leftover area with grey:

def letterbox_image(img, inp_dim):
    '''resize image with unchanged aspect ratio using padding'''
    img_w, img_h = img.shape[1], img.shape[0]
    w, h = inp_dim
    new_w = int(img_w * min(w/img_w, h/img_h))
    new_h = int(img_h * min(w/img_w, h/img_h))
    resized_image = cv2.resize(img, (new_w,new_h), interpolation = cv2.INTER_CUBIC)
    
    canvas = np.full((inp_dim[1], inp_dim[0], 3), 128)

    canvas[(h-new_h)//2:(h-new_h)//2 + new_h,(w-new_w)//2:(w-new_w)//2 + new_w,  :] = resized_image
    
    return canvas
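
As a worked example, letterboxing a hypothetical 640x480 image to a (416, 416) input: the scale is min(416/640, 416/480) = 0.65, so the content is resized to 416x312 and pasted onto the grey canvas with 52 rows of padding above and below:

img = np.zeros((480, 640, 3), dtype=np.uint8)   # hypothetical 480x640 BGR image
out = letterbox_image(img, (416, 416))
print(out.shape)                                # (416, 416, 3); image content occupies rows 52..363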

Convert the OpenCV image into the network's input format:

def prep_image(img, inp_dim):
    """
    Prepare image for inputting to the neural network. 
    
    Returns a Variable 
    """

    img = cv2.resize(img, (inp_dim, inp_dim))                      # plain resize to inp_dim x inp_dim
    img = img[:,:,::-1].transpose((2,0,1)).copy()                  # BGR -> RGB, HWC -> CHW
    img = torch.from_numpy(img).float().div(255.0).unsqueeze(0)    # scale to [0,1], add batch dimension
    return img
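
Note that this version simply stretches the image with cv2.resize. The box rescaling later in this post undoes letterbox padding, and the reference implementation's final prep_image accordingly calls letterbox_image instead of a plain resize; a minimal sketch, assuming the letterbox_image defined above:

def prep_image(img, inp_dim):
    """Letterbox to inp_dim x inp_dim, convert BGR HWC -> RGB CHW float tensor in [0, 1]."""
    img = letterbox_image(img, (inp_dim, inp_dim))
    img = img[:, :, ::-1].transpose((2, 0, 1)).copy()
    img = torch.from_numpy(img).float().div(255.0).unsqueeze(0)
    return img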

Besides the transformed images, we also keep a list im_dim_list that stores the dimensions of the original images.

#PyTorch Variables for images
im_batches = list(map(prep_image, loaded_ims, [inp_dim for x in range(len(imlist))]))

#List containing dimensions of original images; repeat(1,2) makes each row (w, h, w, h)
im_dim_list = [(x.shape[1], x.shape[0]) for x in loaded_ims]
im_dim_list = torch.FloatTensor(im_dim_list).repeat(1,2)

if CUDA:
    im_dim_list = im_dim_list.cuda()

Creating Batches

leftover = 0
if (len(im_dim_list) % batch_size):
   leftover = 1

if batch_size != 1:
   num_batches = len(imlist) // batch_size + leftover            
   im_batches = [torch.cat((im_batches[i*batch_size : min((i +  1)*batch_size,
                       len(im_batches))]))  for i in range(num_batches)]  
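
For example, with 5 images and --bs 2, leftover is 1 and num_batches = 5 // 2 + 1 = 3, so im_batches ends up holding three tensors of 2, 2 and 1 images, each of shape (n, 3, inp_dim, inp_dim).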

The Detection Loop

For each batch, we measure the time from feeding the input to the network to producing the final output as the detection time. One attribute in each row returned by write_results is the index of the image within its batch; we transform this attribute into the index of the image in imlist, which contains the paths of all images.
After that, we print the time taken for each detection and the objects detected in each image.
If write_results returns an int (0) for a batch, meaning no detections were made, we use continue to skip the rest of the loop body.

write = 0
start_det_loop = time.time()
for i, batch in enumerate(im_batches):
    #load the image 
    start = time.time()
    if CUDA:
        batch = batch.cuda()

    prediction = model(Variable(batch, volatile = True), CUDA)   # volatile is deprecated in newer PyTorch; wrap the call in torch.no_grad() instead

    prediction = write_results(prediction, confidence, num_classes, nms_conf = nms_thesh)

    end = time.time()

    if type(prediction) == int:

        for im_num, image in enumerate(imlist[i*batch_size: min((i +  1)*batch_size, len(imlist))]):
            im_id = i*batch_size + im_num
            print("{0:20s} predicted in {1:6.3f} seconds".format(image.split("/")[-1], (end - start)/batch_size))
            print("{0:20s} {1:s}".format("Objects Detected:", ""))
            print("----------------------------------------------------------")
        continue

    prediction[:,0] += i*batch_size    #transform the attribute from index in batch to index in imlist 

    if not write:                      #If we haven't initialised output
        output = prediction  
        write = 1
    else:
        output = torch.cat((output,prediction))

    for im_num, image in enumerate(imlist[i*batch_size: min((i +  1)*batch_size, len(imlist))]):
        im_id = i*batch_size + im_num
        objs = [classes[int(x[-1])] for x in output if int(x[0]) == im_id]
        print("{0:20s} predicted in {1:6.3f} seconds".format(image.split("/")[-1], (end - start)/batch_size))
        print("{0:20s} {1:s}".format("Objects Detected:", " ".join(objs)))
        print("----------------------------------------------------------")

    if CUDA:
        torch.cuda.synchronize()   

torch.cuda.synchronize makes sure the CUDA kernels have finished before control returns to the CPU. Without it, end = time.time() could be recorded before the GPU has actually finished the batch, making the timing inaccurate.

Drawing Bounding Boxes on the Images

We use a try-except block to check whether at least one detection was made; if output was never created, no object was detected in any image and we exit the program.

try:
    output
except NameError:
    print ("No detections were made")
    exit()

Before drawing the boxes, note that the predictions in output are expressed in the coordinate system of the network input (inp_dim × inp_dim), not in the dimensions of the original images, so each box's attributes have to be mapped back to the original image.
We also have to undo the scaling and padding introduced when the images were letterboxed.

im_dim_list = torch.index_select(im_dim_list, 0, output[:,0].long())

scaling_factor = torch.min(inp_dim/im_dim_list,1)[0].view(-1,1)


output[:,[1,3]] -= (inp_dim - scaling_factor*im_dim_list[:,0].view(-1,1))/2
output[:,[2,4]] -= (inp_dim - scaling_factor*im_dim_list[:,1].view(-1,1))/2

In letterbox_image we resized the image by a scaling factor; we undo that scaling here so the box coordinates line up with the original image:

output[:,1:5] /= scaling_factor
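
To make this concrete with a hypothetical 1280x720 image and inp_dim = 416: scaling_factor = min(416/1280, 416/720) = 0.325, so the letterboxed content is 416x234 and centred vertically. The y coordinates therefore have (416 - 0.325*720)/2 = 91 pixels of padding subtracted (the x padding is 0), and dividing by 0.325 maps the boxes back onto the original 1280x720 image.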

Clip any boxes whose coordinates fall outside the image boundaries:

for i in range(output.shape[0]):
    output[i, [1,3]] = torch.clamp(output[i, [1,3]], 0.0, im_dim_list[i,0])
    output[i, [2,4]] = torch.clamp(output[i, [2,4]], 0.0, im_dim_list[i,1])
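
The timing summary at the end of the script references a checkpoint called output_recast that marks the end of this output post-processing; in the reference code it is set here, just before the colour palette is loaded:

output_recast = time.time()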

If an image contains many bounding boxes, drawing them all in a single colour is not ideal, so we pick a colour at random for each box from a pickled palette file:

class_load = time.time()
colors = pkl.load(open("pallete", "rb"))

Now write the function that draws the boxes:

draw = time.time()

def write(x, results, color):
    c1 = tuple(x[1:3].int())      # top-left corner
    c2 = tuple(x[3:5].int())      # bottom-right corner
    img = results[int(x[0])]      # x[0] is the index of the image in loaded_ims
    cls = int(x[-1])
    label = "{0}".format(classes[cls])
    cv2.rectangle(img, c1, c2,color, 1)
    t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1 , 1)[0]
    c2 = c1[0] + t_size[0] + 3, c1[1] + t_size[1] + 4
    cv2.rectangle(img, c1, c2,color, -1)    # filled rectangle behind the class label
    cv2.putText(img, label, (c1[0], c1[1] + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, [225,255,255], 1);
    return img

The function above draws a rectangle on the image in the colour it is given (chosen at random from colors when it is called), then draws a filled rectangle at the top-left corner of the box and writes the detected class name inside it. The -1 argument to cv2.rectangle is what makes that rectangle filled.
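
One call is still missing: the one that actually applies write to every row of output, drawing onto the images in loaded_ims in place. Assuming the three-argument write defined above (the reference code instead picks the colour inside write), it would be:

list(map(lambda x: write(x, loaded_ims, random.choice(colors)), output))
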
Each result image is saved under the name det_ plus the original file name. We build a list of these output paths and will write the detection images to them.

det_names = pd.Series(imlist).apply(lambda x: "{}/det_{}".format(args.det,x.split("/")[-1]))

Finally, write the images with their detections drawn to the paths in det_names:

list(map(cv2.imwrite, det_names, loaded_ims))
end = time.time()

Printing the Timing Summary

At the end of the detector we print a summary of how long each stage of the code took. This is useful for comparing how different hyperparameters affect detection speed. Hyperparameters such as the batch size, object confidence and NMS threshold (bs, confidence, nms_thresh) can all be set from the command line.

print("SUMMARY")
print("----------------------------------------------------------")
print("{:25s}: {}".format("Task", "Time Taken (in seconds)"))
print()
print("{:25s}: {:2.3f}".format("Reading addresses", load_batch - read_dir))
print("{:25s}: {:2.3f}".format("Loading batch", start_det_loop - load_batch))
print("{:25s}: {:2.3f}".format("Detection (" + str(len(imlist)) +  " images)", output_recast - start_det_loop))
print("{:25s}: {:2.3f}".format("Output Processing", class_load - output_recast))
print("{:25s}: {:2.3f}".format("Drawing Boxes", end - draw))
print("{:25s}: {:2.3f}".format("Average time_per_img", (end - load_batch)/len(imlist)))
print("----------------------------------------------------------")


torch.cuda.empty_cache()

Testing the Object Detector

Run the following in a terminal:

python detector.py --images dog-cycle-car.png --det det

This produces output along these lines:

Loading network.....
Network successfully loaded
dog-cycle-car.png    predicted in  2.456 seconds
Objects Detected:    bicycle truck dog
----------------------------------------------------------
SUMMARY
----------------------------------------------------------
Task                     : Time Taken (in seconds)

Reading addresses        : 0.002
Loading batch            : 0.120
Detection (1 images)     : 2.457
Output Processing        : 0.002
Drawing Boxes            : 0.076
Average time_per_img     : 2.657
----------------------------------------------------------

An image named det_dog-cycle-car.png is saved to the det directory.

Running Detection on Video or a Webcam

To run the detector on a video or a webcam feed, most of the code stays the same; instead of iterating over batches of images we iterate over the frames of the video.
The video code lives in video.py and is very similar to the detector script above.
First, open the video file (or the webcam) with OpenCV:

videofile = "video.avi" #or path to the video file. 

cap = cv2.VideoCapture(videofile)  

#cap = cv2.VideoCapture(0)  for webcam

assert cap.isOpened(), 'Cannot capture source'

frames = 0

We do not deal with batches when processing video; we handle one frame at a time.
On every iteration we track the number of frames captured in a variable called frames and divide it by the elapsed time to print the FPS of the video.
We use cv2.imshow to display the frame with its bounding boxes drawn on it; if the user presses Q, we break out of the loop and the video ends.

frames = 0  
start = time.time()

while cap.isOpened():
    ret, frame = cap.read()
    
    if ret:   
        img = prep_image(frame, inp_dim)
#        cv2.imshow("a", frame)
        im_dim = frame.shape[1], frame.shape[0]
        im_dim = torch.FloatTensor(im_dim).repeat(1,2)   
                     
        if CUDA:
            im_dim = im_dim.cuda()
            img = img.cuda()

        output = model(Variable(img, volatile = True), CUDA)
        output = write_results(output, confidence, num_classes, nms_conf = nms_thesh)


        if type(output) == int:
            frames += 1
            print("FPS of the video is {:5.4f}".format( frames / (time.time() - start)))
            cv2.imshow("frame", frame)
            key = cv2.waitKey(1)
            if key & 0xFF == ord('q'):
                break
            continue
        output[:,1:5] = torch.clamp(output[:,1:5], 0.0, float(inp_dim))

        im_dim = im_dim.repeat(output.size(0), 1)/inp_dim
        output[:,1:5] *= im_dim

        classes = load_classes('data/coco.names')
        colors = pkl.load(open("pallete", "rb"))

        list(map(lambda x: write(x, [frame], random.choice(colors)), output))   # reuse detector.py's write; wrap frame in a list since write indexes results by x[0]
        
        cv2.imshow("frame", frame)
        key = cv2.waitKey(1)
        if key & 0xFF == ord('q'):
            break
        frames += 1
        print(time.time() - start)
        print("FPS of the video is {:5.2f}".format( frames / (time.time() - start)))
    else:
        break 
