darknet代码分析先从Inference(推理)部分讲起吧,毕竟只有forward部分,而且比backward要简单一些。
首先声明,我分析的代码是基于darknentAB版本,非官方版本, 毕竟他俩的代码还是有稍许不一样。
最上层的入口函数自然是darknet.c中main(),然后就是run_detector(...) 它是training和Inference功能实现接口。再次重申,本博文讲的是后者,即推理部分,其对应的子接口是test_detector(。。。)
好了,正文刚刚开始,下面重点分析test_detector(。。。)的实现代码。
1)准备部分代码如下,我在里面对每一行都添加了具体注释
看代码注释前,讲解一下
调用darknet推理接口的形式是:
./darknet detector test cfg/xxx.data cfg/xxx.cfg yolov3.weights data/xxx.jpg
这里,./darknet是unix的可执行文件,在windows下就是后缀exe文件了。调用它就会进入main,
其后几个都是参数,存放到main的argv里面, 下面分别讲解:
detector和test参数 是作为flag来保证程序执行进入test_detector()
cfg/xxx.data 是来描述目标类型名字定义在哪个文件,训练用的样本图片和label文件在哪里,总之,训练时这个文件就应该ready
cfg/xxx.cfg 用来描述网络层次结构以及每一个结构都有哪些成员及参数, 非常重要的配置文件,训练时也应该ready
yolov3.weights 这个是训练好的权值文件
data/xxx.jpg 待推理的目标图片,地址及文件名由客户指定
//把data config文件内容读到options链表结构里面
list *options = read_data_cfg(datacfg);
//通过 names这个key,找到其对应的描述class names的data文件
char *name_list = option_find_str(options, "names", "data/names.list");
int names_size = 0;
//读取所有class names
char **names = get_labels_custom(name_list, &names_size); //get_labels(name_list);
//读取字母如a、b、c、d等字母所对应的小图片,将来会被目标检测框上的字符显示所调用
image **alphabet = load_alphabet();
//parse 描述该算法模型的网络结构的config文件并赋值给net变量
//注意 network是非常关键的变量类型,用来描述模型的网络结构
network net = parse_network_cfg_custom(cfgfile, 1); // set batch=1
//根据net网络结构,把weightfile二进制文件正确的赋给net结构里面的每一个成员
if (weightfile) {
load_weights(&net, weightfile);
}
2)另外一个小铺垫如下,即讲batch normal的输入权值进行标准化。这里面由个疑问是,按定义,BN应该是输出值进行标准化,而不是权值,此外,官方版本darknet源码没有看到这个函数的存在。
fuse_conv_batchnorm(net);
3)开始resize成nework size并计算每一层的输出值
//忍不住,吐槽一下 image resize是简单的双线性插值,为了更好的识别效果,可以考虑采用更好的插值算法。
image sized = resize_image(im, net.w, net.h);
。。。 。。。
//X就是上面的 reized image data
//net就是network类型网络结构变量,而且训练好的权值已经正确的赋值给net里面各个layer的成
//员。 然后开始根据输入和权值开始计算每一层的输出并存放在l.output变量里面
//这个是最耗费时间的函数,往往计算性能优化就在 W.X里面
network_predict(net, X);
4)计算完后,最后开始找detections,并用threshold来过滤掉不合适的目标检测。
detection *dets = get_network_boxes(&net, im.w, im.h, thresh, hier_thresh, 0, 1, &nboxes, letterbox);
if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
draw_detections_v3(im, dets, nboxes, thresh, names, alphabet, l.classes, ext_output);
get_network_boxes调用了两个重要函数:
//使用3个yolo层19x19 38x38 76x76来分别检测目标检测的置信度是否大于0.25
//如果大于,则记录下来,最后返回给dets
detection *dets = make_network_boxes(net, thresh, num);
//对检测出来的dets的每个结构变量进行成员赋值,包括概率,类型名称等
fill_network_boxes(net, w, h, thresh, hier, map, relative, dets, letter);
do_nms_sort()是检测出来的目标,进行非最大化抑制。即先找出最大概率的目标,然后将iou比较大的其它目标检测概率置成0.
最后调用draw_detections_v3()来对每个detection进行检查,看它对应的哪个类型的概率最高,而且必须超过设定的threshold,才记录下来作为final detect results。 如果不超过threshold 就被放弃。 上面的接口参数可以加 -thresold来指定,否则调用缺省threshold(0.48)。 最后的最后 将这些final detects画在图片上。