YOLOv3 Code Reading Notes: test.py (Part 10)

These notes walk through the YOLOv3 source code. I am still a beginner, so please bear with any places where my understanding falls short. The code is forked from eriklindernoren/PyTorch-YOLOv3; to download it, search for the repository on GitHub.
This post covers test.py.
The main flow of test.py is as follows (a typical way to launch the script is shown right after this list):
#1. Define the evaluate function
#2. Parse the command-line arguments
#3. Print the arguments in use
#4. Parse the validation set path and class_names from the data config
#5. Create the model
#6. Load the model weights
#7. Call evaluate to get the evaluation results
#8. Print the per-class AP
#9. Print the mean AP (mAP) over all classes
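
Assuming the COCO validation data and the pretrained weights (weights/yolov3.weights) are already in place, the script is typically launched with its defaults, for example:

python3 test.py --weights_path weights/yolov3.weights --data_config config/coco.data

Every flag has a default value (see the argparse section below), so a plain "python3 test.py" also works as long as the paths match your setup.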

from __future__ import division

from models import *
from utils.utils import *
from utils.datasets import *
from utils.parse_config import *

import os
import sys
import time
import datetime
import argparse
import tqdm

import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim
# 1. Define the evaluate function

def evaluate(model, path, iou_thres, conf_thres, nms_thres, img_size, batch_size):
    # Args: the model, the path to the validation image list (valid_path), the IoU
    # threshold, the confidence threshold, the NMS threshold, the input image size,
    # and the batch size
    model.eval()  # switch the model to evaluation mode
    
    # Get dataloader (no augmentation and no multi-scale resizing during evaluation)
    dataset = ListDataset(path, img_size=img_size, augment=False, multiscale=False)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=1, collate_fn=dataset.collate_fn
    )

    Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    labels = []
    sample_metrics = []  # List of tuples (TP, confs, pred)
    # Loop over the validation set one batch at a time
    for batch_i, (_, imgs, targets) in enumerate(tqdm.tqdm(dataloader, desc="Detecting objects")):

        # Extract ground-truth class labels
        labels += targets[:, 1].tolist()
        # targets has shape (num_boxes, 6): (sample_idx, cls, center_x, center_y, width, height),
        # where sample_idx is the index of the image this box belongs to within the batch
        
        # Rescale targets
        targets[:, 2:] = xywh2xyxy(targets[:, 2:])  # (cx, cy, w, h) -> (x1, y1, x2, y2) corner format
        targets[:, 2:] *= img_size  # scale the normalized coordinates up to the network input size

        imgs = Variable(imgs.type(Tensor), requires_grad=False)  # cast the image batch to the right tensor type (GPU if available)

        with torch.no_grad():
            outputs = model(imgs)  # forward pass: raw predictions for the whole batch
            outputs = non_max_suppression(outputs, conf_thres=conf_thres, nms_thres=nms_thres)  # filter by confidence and apply NMS
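            # outputs is now a list with one entry per image; each entry holds the
            # detections kept after NMS (box corners, confidence and predicted class),
            # or None when nothing survived the thresholds.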

        sample_metrics += get_batch_statistics(outputs, targets, iou_threshold=iou_thres)  # per-image (TP flags, confidences, predicted classes) for this batch

    # Concatenate sample statistics
    true_positives, pred_scores, pred_labels = [np.concatenate(x, 0) for x in list(zip(*sample_metrics))]
    precision, recall, AP, f1, ap_class = ap_per_class(true_positives, pred_scores, pred_labels, labels)

    return precision, recall, AP, f1, ap_class  # metrics accumulated over the whole validation set
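
Inside evaluate, get_batch_statistics matches each predicted box against the ground truth of its image and records, for every detection, a true-positive flag together with its confidence and predicted class; ap_per_class then turns those accumulated records into per-class precision, recall, AP and F1. As a rough illustration of that second step only, the snippet below computes AP for a single class; it is a simplified sketch of the idea, not the implementation in utils.utils, and ap_single_class / n_gt are names made up for this example.

import numpy as np

def ap_single_class(tp, conf, n_gt):
    # tp: 0/1 array, 1 where a detection matched a previously unclaimed ground-truth box
    # conf: confidence of each detection; n_gt: number of ground-truth boxes of this class
    order = np.argsort(-conf)      # rank detections by descending confidence
    tp = tp[order]
    tp_cum = np.cumsum(tp)         # cumulative true positives as the cut-off is lowered
    fp_cum = np.cumsum(1 - tp)     # cumulative false positives
    recall = tp_cum / (n_gt + 1e-16)
    precision = tp_cum / (tp_cum + fp_cum)
    # Integrate precision over recall under a monotonically decreasing precision envelope
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([0.0], precision, [0.0]))
    mpre = np.maximum.accumulate(mpre[::-1])[::-1]
    changed = np.where(mrec[1:] != mrec[:-1])[0]
    return np.sum((mrec[changed + 1] - mrec[changed]) * mpre[changed + 1])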


if __name__ == "__main__":
    # 2. Parse the command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=8, help="size of each image batch")
    parser.add_argument("--model_def", type=str, default="config/yolov3.cfg", help="path to model definition file")
    parser.add_argument("--data_config", type=str, default="config/coco.data", help="path to data config file")
    parser.add_argument("--weights_path", type=str, default="weights/yolov3.weights", help="path to weights file")
    parser.add_argument("--class_path", type=str, default="data/coco.names", help="path to class label file")
    parser.add_argument("--iou_thres", type=float, default=0.5, help="iou threshold required to qualify as detected")
    parser.add_argument("--conf_thres", type=float, default=0.001, help="object confidence threshold")
    parser.add_argument("--nms_thres", type=float, default=0.5, help="iou thresshold for non-maximum suppression")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    opt = parser.parse_args()
    # 3. Print the arguments in use
    print(opt)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # 4. Parse the validation set path and class_names from the data config
    data_config = parse_data_config(opt.data_config)
    valid_path = data_config["valid"]
    class_names = load_classes(data_config["names"])
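    # For reference, the data config parsed here (config/coco.data) is a plain
    # key=value file; in this repo it looks roughly like:
    #   classes=80
    #   train=data/coco/trainvalno5k.txt
    #   valid=data/coco/5k.txt
    #   names=data/coco.names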
    
    # 5. Create the model
    # Initiate model
    model = Darknet(opt.model_def).to(device)
    
    # 6. Load the model weights
    if opt.weights_path.endswith(".weights"):
        # Load darknet weights
        model.load_darknet_weights(opt.weights_path)
    else:
        # Load checkpoint weights
        model.load_state_dict(torch.load(opt.weights_path))
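    # ".weights" files are Darknet's original binary weight format (handled by
    # load_darknet_weights); any other path is treated as a PyTorch checkpoint
    # saved with torch.save(model.state_dict(), ...).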

    print("Compute mAP...")
    
    # 7. Call evaluate to get the evaluation results
    precision, recall, AP, f1, ap_class = evaluate(
        model,
        path=valid_path,
        iou_thres=opt.iou_thres,
        conf_thres=opt.conf_thres,
        nms_thres=opt.nms_thres,
        img_size=opt.img_size,
        batch_size=opt.batch_size,  # use the parsed --batch_size instead of a hard-coded 8
    )
   
    # 8. Print the per-class AP
    print("Average Precisions:")
    for i, c in enumerate(ap_class):  # one line per evaluated class
        print(f"+ Class '{c}' ({class_names[c]}) - AP: {AP[i]}")
    
    # 9. Print the mean AP (mAP) over all classes
    print(f"mAP: {AP.mean()}")
