使用category命令评估YOLO模型对每类物体的检测性能

将下面代码添加到darknet/src/detector.c中:

/*
 * Append one result line per confident detection of a single image to the
 * per-class result stream fps[class_id].
 *
 * A detection is reported when its highest class probability is >= thresh.
 * It is compared against every ground-truth box read from the image's label
 * file. The label path is derived from `path` by replacing "images" /
 * "JPEGImages" with "labels" and ".jpg" / ".JPEG" with ".txt". The detection
 * is marked correct when the best-overlapping ground truth (of any class) has
 * IoU > iou_thresh AND carries the same class id.
 *
 * Output line format:
 *   path, class_id, correct, prob, best_iou, xmin, ymin, xmax, ymax
 *
 * fps         per-class output streams indexed by class id (NULL entries are skipped)
 * path        image path; also used to locate the label file
 * boxes       `total` predicted boxes (center x/y, width, height). Label boxes are
 *             scaled by w/h before comparison, so these are presumably in pixels.
 *             NOTE(review): confirm the caller uses get_region_boxes(..., relative=0).
 * probs       `total` rows of at least `classes` class probabilities
 * w, h        image size; scales the normalized label boxes and clips the output corners
 *
 * NOTE(review): one ground-truth box can be the best match of several
 * detections, and each of them is then counted correct. NMS beforehand
 * limits, but does not rule out, such duplicates.
 */
void print_category(FILE **fps, char *path, box *boxes, float **probs, int total, int classes, int w, int h, float thresh, float iou_thresh)
{
    char labelpath[4096];
    find_replace(path, "images", "labels", labelpath);
    find_replace(labelpath, "JPEGImages", "labels", labelpath);
    find_replace(labelpath, ".jpg", ".txt", labelpath);
    find_replace(labelpath, ".JPEG", ".txt", labelpath);

    int num_labels = 0;
    box_label *truth = read_boxes(labelpath, &num_labels);

    for (int i = 0; i < total; ++i) {
        int class_id = max_index(probs[i], classes);
        float prob = probs[i][class_id];
        if (prob < thresh) continue;

        /* Find the ground-truth box (of any class) that overlaps this detection most. */
        float best_iou = 0;
        int best_iou_id = -1;   /* -1: no ground-truth box overlaps at all */
        for (int j = 0; j < num_labels; ++j) {
            box t = { truth[j].x*w, truth[j].y*h, truth[j].w*w, truth[j].h*h };
            float iou = box_iou(boxes[i], t);
            if (iou > best_iou) {
                best_iou = iou;
                best_iou_id = j;
            }
        }

        /* Only index truth[] when a match exists; the label file may be empty. */
        int correct = best_iou_id >= 0
                      && best_iou > iou_thresh
                      && truth[best_iou_id].id == class_id;

        /* Convert center/size to corner coordinates, clipped to the image. */
        float xmin = boxes[i].x - boxes[i].w / 2.;
        float xmax = boxes[i].x + boxes[i].w / 2.;
        float ymin = boxes[i].y - boxes[i].h / 2.;
        float ymax = boxes[i].y + boxes[i].h / 2.;

        if (xmin < 0) xmin = 0;
        if (ymin < 0) ymin = 0;
        if (xmax > w) xmax = w;
        if (ymax > h) ymax = h;

        if (!fps[class_id]) continue;   /* output file for this class could not be opened */
        fprintf(fps[class_id], "%s, %d, %d, %f, %f, %f, %f, %f, %f\n", path, class_id, correct, prob, best_iou, xmin, ymin, xmax, ymax);
    }

    /* NOTE(review): read_boxes heap-allocates the label array (darknet frees it
     * the same way after reading training labels); confirm for this fork. */
    free(truth);
}



void validate_detector_category(char *datacfg, char *cfgfile, char *weightfile, char *outfile)
{
    network net = parse_network_cfg(cfgfile);
    int j;
    list *options = read_data_cfg(datacfg);
    char *valid_images = option_find_str(options, "valid", "data/train.list");
    char *name_list = option_find_str(options, "names", "data/names.list");
    char *prefix = option_find_str(options, "results", "results");
    char **names = get_labels(name_list);
    char *mapf = option_find_str(options, "map", 0);
    int *map = 0;
    if (mapf) map = read_map(mapf);

    if (weightfile){
        load_weights(&net, weightfile);
    }
    set_batch_network(&net, 1);
    fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    srand(time(0));

    list *plist = get_paths(valid_images);
    char **paths = (char **)list_to_array(plist);

    layer l = net.layers[net.n - 1];
    int classes = l.classes;

    char buff[1024];
    FILE **fps = 0;
    if (!outfile) outfile = "paul_";
    fps = calloc(classes, sizeof(FILE *));
    for (j = 0; j < classes; ++j){
        _snprintf(buff, 1024, "%s/%s%s.txt", prefix, outfile, names[j]);
        fps[j] = fopen(buff, "w");
    }


    box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
    float **probs = calloc(l.w*l.h*l.n, sizeof(float *));
    for (j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));

    int m = plist->size;
    int i = 0;
    int t;

    float thresh = .25;
    float iou_thresh = .5;
    float nms = .45;

    int nthreads = 4;
    image *val = calloc(nthreads, sizeof(image));
    image *val_resized = calloc(nthreads, sizeof(image));
    image *buf = calloc(nthreads, sizeof(image));
    image *buf_resized = calloc(nthreads, sizeof(image));
    pthread_t *thr = calloc(nthreads, sizeof(pthread_t));

    load_args args = { 0 };
    args.w = net.w;
    args.h = net.h;
    args.type = IMAGE_DATA;

    for (t = 0; t < nthreads; ++t){
        args.path = paths[i + t];
        args.im = &buf[t];
        args.resized = &buf_resized[t];
        thr[t] = load_data_in_thread(args);
    }
    time_t start = time(0);
    for (i = nthreads; i < m + nthreads; i += nthreads){
        fprintf(stderr, "%d\n", i);
        for (t = 0; t < nthreads && i + t - nthreads < m; ++t){
            pthread_join(thr[t], 0);
            val[t] = buf[t];
            val_resized[t] = buf_resized[t];
        }
        for (t = 0; t < nthreads && i + t < m; ++t){
            args.path = paths[i + t];
            args.im = &buf[t];
            args.resized = &buf_resized[t];
            thr[t] = load_data_in_thread(args);
        }
        for (t = 0; t < nthreads && i + t - nthreads < m; ++t){
            char *path = paths[i + t - nthreads];
            float *X = val_resized[t].data;
            network_predict(net, X);
            int w = val[t].w;
            int h = val[t].h;
            get_region_boxes(l, w, h, thresh, probs, boxes, 0, map, .5, 0);
            if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, classes, nms);
            print_category(fps, path, boxes, probs, l.w*l.h*l.n, classes, w, h, thresh, iou_thresh);
            free_image(val[t]);
            free_image(val_resized[t]);
        }
    }
    for (j = 0; j < classes; ++j){
        if (fps) fclose(fps[j]);
    }
    fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
}

修改run_detector()函数，增加category分支:

/*
 * Entry point for `darknet detector <command> <datacfg> <cfg> [weights] [file]`.
 *
 * Optional flags (-prefix, -thresh, -hier, -c, -s, -gpus, -out, -clear) are
 * read first. Then <command> is dispatched to test / train / valid / valid2 /
 * recall / category / demo. -gpus takes a comma-separated list of GPU ids;
 * without it the global gpu_index is used.
 */
void run_detector(int argc, char **argv)
{
    char *prefix = find_char_arg(argc, argv, "-prefix", 0);
    float thresh = find_float_arg(argc, argv, "-thresh", .24);
    float hier_thresh = find_float_arg(argc, argv, "-hier", .5);
    int cam_index = find_int_arg(argc, argv, "-c", 0);
    int frame_skip = find_int_arg(argc, argv, "-s", 0);
    char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
    char *outfile = find_char_arg(argc, argv, "-out", 0);
    int clear = find_arg(argc, argv, "-clear");

    /* argv[3] (datacfg) and argv[4] (cfg) are both required.
     * NOTE(review): darknet's find_*_arg helpers appear to remove matched flags
     * from argv and NULL the vacated tail without updating argc, so check the
     * cfg slot itself as well as the count. */
    if (argc < 5 || !argv[4]) {
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    int *gpus = 0;
    int gpu = 0;
    int ngpus = 0;
    if (gpu_list) {
        printf("%s\n", gpu_list);
        size_t len = strlen(gpu_list);
        ngpus = 1;
        for (size_t k = 0; k < len; ++k) {
            if (gpu_list[k] == ',') ++ngpus;
        }
        gpus = calloc(ngpus, sizeof(int));
        char *cur = gpu_list;
        for (int k = 0; k < ngpus; ++k) {
            gpus[k] = atoi(cur);
            char *comma = strchr(cur, ',');
            if (!comma) break;   /* last id: never evaluate NULL + 1 */
            cur = comma + 1;
        }
    } else {
        gpu = gpu_index;
        gpus = &gpu;
        ngpus = 1;
    }

    char *datacfg = argv[3];
    char *cfg = argv[4];
    char *weights = (argc > 5) ? argv[5] : 0;
    char *filename = (argc > 6) ? argv[6] : 0;
    if (0 == strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, outfile);
    else if (0 == strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear);
    else if (0 == strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile);
    else if (0 == strcmp(argv[2], "valid2")) validate_detector_flip(datacfg, cfg, weights, outfile);
    else if (0 == strcmp(argv[2], "recall")) validate_detector_recall(cfg, weights);
    //yim 2017.05.16: per-class evaluation output (see validate_detector_category)
    else if (0 == strcmp(argv[2], "category")) validate_detector_category(datacfg, cfg, weights, outfile);
    else if (0 == strcmp(argv[2], "demo")) {
        list *options = read_data_cfg(datacfg);
        int classes = option_find_int(options, "classes", 20);
        char *name_list = option_find_str(options, "names", "data/names.list");
        char **names = get_labels(name_list);
        demo(cfg, weights, thresh, cam_index, filename, names, classes, frame_skip, prefix, hier_thresh);
    }

    /* gpus is heap-allocated only when -gpus was given. */
    if (gpus != &gpu) free(gpus);
}

执行命令（以下为传给darknet可执行文件的参数；可以用 -out 指定结果文件名前缀，默认为 paul_）:

"detector" "category" "E:/projects&code/darknet_yolo/cfg/voc.data" "E:/projects&code/darknet_yolo/cfg/tiny-yolo-voc.cfg" "E:/projects&code/darknet_yolo/tiny-yolo-voc.weights"

在.data文件中results项指定的目录下会生成各类物体的验证结果。未指定results项时该目录为results，且需要事先创建。有多少类物体就会生成多少个txt文件，文件名为“输出前缀+类别名.txt”。每个txt文件的每一行对应一个检测框，依次为path, class_id, correct, prob, best_iou, xmin, ymin, xmax, ymax。

接下来可以用evalute.py脚本解析这些txt文件，对每类物体的检测结果做汇总评估。脚本中的类别数class_num、验证集图片列表路径和类别名文件路径是写死的，使用前需按实际情况修改。evalute.py脚本如下:

# coding=utf-8
# Companion tool for the `category` command.
# `category` is a command added to darknet's detector.c; it writes the
# per-class validation results that this script summarizes.
# Run: ./darknet detector category cfg/paul.data cfg/yolo-paul.cfg backup/yolo-paul_final.weights
# Put this script in the results directory and run it there. For every file
# whose name contains "txt" it prints id, avg_iou, avg_correct_iou,
# avg_precision, avg_recall and avg_score.
# It also writes low_list.log (classes with precision or recall below 0.3) and
# high_list.log (classes with both precision and recall above 0.6).


import os
from os import listdir, getcwd
from os.path import join
import shutil

# Number of object classes. get_val_cat_num() creates one counter per class
# from this value and indexes it by the class id of each label line, so it
# must exceed every class id used in the label files.
class_num = 20


# Validation summary for one object class, built from a per-class result file
# written by darknet's `detector category` command. Each line of that file is:
#   path, class_id, correct, prob, best_iou, xmin, ymin, xmax, ymax
class CategoryValidation:
    id = 0  # class id, read from the first record (-1 when the file has no records)
    path = ""  # path of the result file
    total_num = 0  # number of ground-truth boxes of this class in the label files
    proposals_num = 0  # number of boxes of this class predicted during validation
    correct_num = 0  # predictions marked correct (IoU > 0.5 with a ground truth of the same class)
    iou_num = 0  # number of predictions whose best IoU exceeds 0.5, regardless of class
    iou_sum = 0  # sum of those IoUs
    correct_iou_sum = 0  # sum of the IoUs of correct predictions
    score_sum = 0  # sum of the probabilities of correct predictions
    avg_iou = 0  # iou_sum / iou_num
    avg_correct_iou = 0  # correct_iou_sum / correct_num
    avg_precision = 0  # correct_num / proposals_num
    avg_recall = 0  # correct_num / total_num
    avg_score = 0  # score_sum / correct_num

    def __init__(self, path, val_cat_num):
        """Read the result file at `path` and compute the class statistics.

        val_cat_num is indexed by class id and gives the number of
        ground-truth boxes of each class (see get_val_cat_num). Averages whose
        denominator is zero are reported as 0.0 instead of raising.
        """
        self.path = path
        self.id = -1
        with open(path) as f:
            for line in f:
                record = self._parse_line(line)
                if record is None:
                    continue
                class_id, correct, prob, best_iou = record
                if self.proposals_num == 0:
                    # Every record in a per-class file carries the same class id.
                    self.id = class_id
                    self.total_num = val_cat_num[class_id]
                self.proposals_num = self.proposals_num + 1.00
                if correct:
                    self.correct_num = self.correct_num + 1.00
                    self.score_sum = self.score_sum + prob
                    self.correct_iou_sum = self.correct_iou_sum + best_iou
                if best_iou > 0.5:
                    self.iou_num = self.iou_num + 1
                    self.iou_sum = self.iou_sum + best_iou

        self.avg_iou = self._ratio(self.iou_sum, self.iou_num)
        self.avg_correct_iou = self._ratio(self.correct_iou_sum, self.correct_num)
        self.avg_precision = self._ratio(self.correct_num, self.proposals_num)
        self.avg_recall = self._ratio(self.correct_num, self.total_num)
        self.avg_score = self._ratio(self.score_sum, self.correct_num)

    @staticmethod
    def _ratio(num, den):
        """num / den as a float, or 0.0 when den is zero."""
        return num / float(den) if den else 0.0

    @staticmethod
    def _parse_line(line):
        """Return (class_id, correct, prob, best_iou) for a record, or None for a blank line."""
        line = line.rstrip()
        if not line:
            return None
        # The image path may itself contain ", "; the other 8 fields never do,
        # so split from the right.
        fields = line.rsplit(', ', 8)
        return int(fields[1]), int(fields[2]), float(fields[3]), float(fields[4])

    def _export(self, prefix, want_correct):
        """Copy the records whose `correct` flag equals want_correct to <prefix><id>.txt in the cwd."""
        out_name = prefix + str(self.id) + ".txt"
        with open(self.path) as src, open(out_name, 'w') as dst:
            for line in src:
                record = self._parse_line(line)
                if record is not None and bool(record[1]) == want_correct:
                    dst.write(line)

    # Export the records of correctly detected boxes to correct_list_<id>.txt.
    # NOTE(review): the name contains "txt", so a later run of this script in the
    # same directory will also pick it up as a result file.
    def get_correct_list(self):
        self._export("correct_list_", True)

    # Export the records of wrongly detected boxes to error_list_<id>.txt.
    def get_error_list(self):
        self._export("error_list_", False)

    def print_eva(self):
        """Print this class's averages on one line."""
        print("id=%d, avg_iou=%f, avg_correct_iou=%f, avg_precision=%f, avg_recall=%f, avg_score=%f \n" % (self.id,
                                                                                                           self.avg_iou,
                                                                                                           self.avg_correct_iou,
                                                                                                           self.avg_precision,
                                                                                                           self.avg_recall,
                                                                                                           self.avg_score))


def IsSubString(SubStrList, Str):
    """Return True if every string in SubStrList occurs in Str (True for an empty list)."""
    return all(substr in Str for substr in SubStrList)


# Return the sorted names of the entries in FindPath whose names contain every
# substring in FlagStr (all entries when FlagStr is empty).
def GetFileList(FindPath, FlagStr=()):
    import os
    return sorted(fn for fn in os.listdir(FindPath)
                  if all(flag in fn for flag in FlagStr))


# Count the ground-truth boxes (ROIs) of every class.
# path: text file listing one image path per line (blank lines are ignored).
# num_classes: size of the result list; defaults to the module-level class_num.
# Returns a list indexed by YOLO class id whose values are the box counts.
def get_val_cat_num(path, num_classes=None):
    if num_classes is None:
        num_classes = class_num
    val_cat_num = [0] * num_classes

    with open(path) as image_list:
        for line in image_list:
            image_path = line.rstrip()
            if not image_path:
                continue
            # Same mapping as print_category in detector.c. darknet's
            # find_replace substitutes only the first occurrence, hence count=1.
            label_path = (image_path.replace('images', 'labels', 1)
                          .replace('JPEGImages', 'labels', 1)
                          .replace('.jpg', '.txt', 1)
                          .replace('.JPEG', '.txt', 1))
            with open(label_path) as labels:
                for label in labels:
                    fields = label.split()
                    if not fields:
                        continue
                    class_id = int(fields[0])
                    val_cat_num[class_id] = val_cat_num[class_id] + 1.00
    return val_cat_num


# Read the class-name list.
# path: class-name file, one class per line, either "name" or "id,name".
# Returns a list indexed by class id whose values are the class names.
# Every line is kept (even blank ones) so indices stay aligned with class ids.
def get_name_list(path):
    name_list = []
    with open(path) as f:
        for line in f:
            # Keep the text after the first comma when present, otherwise the whole line.
            name_list.append(line.rstrip('\r\n').split(',', 1)[-1].strip())
    return name_list


# Script entry point: summarize every result file in the current directory and
# write the low/high precision-recall class lists. The validation-list and
# class-name paths below are specific to the author's machine; edit them before use.
if __name__ == '__main__':
    wd = getcwd()
    val_result_list = GetFileList(wd, ['txt'])
    val_cat_num = get_val_cat_num("E:/ImageSets/VOCdevkit/VOC2012/2012_val.txt")
    name_list = get_name_list("E:/projects&code/darknet_yolo/data/voc.txt")
    with open("low_list.log", 'w') as low_list, open("high_list.log", 'w') as high_list:
        for result in val_result_list:
            cat = CategoryValidation(result, val_cat_num)
            if cat.proposals_num == 0:
                # An empty result file carries no class id, so it cannot be
                # attributed to a class; report and skip it.
                print("skipping %s: no detections" % result)
                continue
            cat.print_eva()
            if cat.avg_precision < 0.3 or cat.avg_recall < 0.3:
                low_list.write("id=%d, name=%s, avg_precision=%f, avg_recall=%f \n" % (cat.id, name_list[cat.id], cat.avg_precision, cat.avg_recall))
            if cat.avg_precision > 0.6 and cat.avg_recall > 0.6:
                high_list.write("id=%d, name=%s, avg_precision=%f, avg_recall=%f \n" % (cat.id, name_list[cat.id], cat.avg_precision, cat.avg_recall))

将本脚本放在上述results目录下执行，会打印出每类物体的评估结果，包括id、avg_iou、avg_correct_iou、avg_precision、avg_recall和avg_score。

id=0, avg_iou=0.632979, avg_correct_iou=0.632979, avg_precision=0.619048, avg_recall=0.702703, avg_score=0.734685 

id=1, avg_iou=0.656112, avg_correct_iou=0.661061, avg_precision=0.589744, avg_recall=0.575000, avg_score=0.779845 

id=2, avg_iou=0.662430, avg_correct_iou=0.663795, avg_precision=0.620253, avg_recall=0.662162, avg_score=0.670480 

id=3, avg_iou=0.628282, avg_correct_iou=0.628282, avg_precision=0.415385, avg_recall=0.397059, avg_score=0.650444 

id=4, avg_iou=0.661582, avg_correct_iou=0.665570, avg_precision=0.236364, avg_recall=0.156627, avg_score=0.664535 

id=5, avg_iou=0.667193, avg_correct_iou=0.661994, avg_precision=0.526316, avg_recall=0.625000, avg_score=0.676449 

id=6, avg_iou=0.624276, avg_correct_iou=0.625075, avg_precision=0.384181, avg_recall=0.412121, avg_score=0.647013 

id=7, avg_iou=0.652051, avg_correct_iou=0.653301, avg_precision=0.666667, avg_recall=0.845070, avg_score=0.683803 

id=8, avg_iou=0.626261, avg_correct_iou=0.624698, avg_precision=0.326389, avg_recall=0.361538, avg_score=0.657096 

id=9, avg_iou=0.651088, avg_correct_iou=0.643851, avg_precision=0.518519, avg_recall=0.700000, avg_score=0.658830 

id=10, avg_iou=0.592246, avg_correct_iou=0.584612, avg_precision=0.160000, avg_recall=0.210526, avg_score=0.709824 

id=11, avg_iou=0.646738, avg_correct_iou=0.644954, avg_precision=0.567568, avg_recall=0.724138, avg_score=0.692331 

id=12, avg_iou=0.647156, avg_correct_iou=0.651284, avg_precision=0.680000, avg_recall=0.755556, avg_score=0.770852 

id=13, avg_iou=0.640733, avg_correct_iou=0.641990, avg_precision=0.614035, avg_recall=0.636364, avg_score=0.658294 

id=14, avg_iou=0.636807, avg_correct_iou=0.637161, avg_precision=0.606667, avg_recall=0.688351, avg_score=0.633678 

id=15, avg_iou=0.631992, avg_correct_iou=0.631992, avg_precision=0.327869, avg_recall=0.317460, avg_score=0.593521 

id=16, avg_iou=0.613670, avg_correct_iou=0.626057, avg_precision=0.300000, avg_recall=0.545455, avg_score=0.653449 

id=17, avg_iou=0.610414, avg_correct_iou=0.611625, avg_precision=0.477273, avg_recall=0.656250, avg_score=0.679039 

id=18, avg_iou=0.642675, avg_correct_iou=0.642675, avg_precision=0.736842, avg_recall=0.608696, avg_score=0.700961 

id=19, avg_iou=0.637432, avg_correct_iou=0.640944, avg_precision=0.395833, avg_recall=0.441860, avg_score=0.657191 

同时，当前目录下会生成low_list.log和high_list.log：low_list.log列出precision或recall低于0.3的类别，high_list.log列出precision和recall均高于0.6的类别。

参考博文:http://blog.csdn.net/hrsstudy/article/details/65644517?utm_source=itdadao&utm_medium=referral

你可能感兴趣的:(Object,Detect)