我们按照代码运行的顺序来看。下面是我简化后的代码(原始代码太长,这里只保留了需要的部分)。
作为一个萌新,写的注释肯定有错误的地方,欢迎大佬们指出。
main函数
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "parser.h"
#include "utils.h"
#include "cuda.h"
#include "blas.h"
#include "connected_layer.h"
#ifdef OPENCV
#include "opencv2/highgui/highgui_c.h"
#endif
extern void run_detector(int argc, char **argv);
/*
 * Program entry point: checks the command line, chooses the GPU index,
 * then passes the arguments on to run_detector().
 *
 * argc / argv: the process command line; at least one argument after the
 *              program name is required.
 * Returns 0 in every case, including when only the usage line is printed.
 */
int main(int argc, char **argv)
{
    /* In _DEBUG builds, enable CRT debug-heap tracking and the leak report. */
#ifdef _DEBUG
    _CrtSetDbgFlag(_CRTDBG_ALLOC_MEM_DF | _CRTDBG_LEAK_CHECK_DF);
#endif

    /* Nothing to do without at least one argument: print usage and stop. */
    if (argc < 2) {
        fprintf(stderr, "usage: %s <function>\n", argv[0]);
        return 0;
    }

    /* "-i <n>" picks the GPU index (default 0); the flag and its value are
     * stripped from argv by find_int_arg. */
    gpu_index = find_int_arg(argc, argv, "-i", 0);

    /* "-nogpu" overrides any "-i" value. */
    if (find_arg(argc, argv, "-nogpu")) {
        gpu_index = -1;
    }

    /* Builds without GPU support always use -1; GPU builds activate the
     * selected device unless it was disabled above. */
#ifndef GPU
    gpu_index = -1;
#else
    if (gpu_index >= 0) {
        cuda_set_device(gpu_index);
    }
#endif

    run_detector(argc, argv);
    return 0;
}
int find_int_arg(int argc, char **argv, char *arg, int def) /* Searches argv for the flag arg. If it is found at index i, the entries at i and i+1 are removed from argv; if it is not found, def is returned. NOTE(review): callers in this file use the result as the parsed integer value (e.g. gpu_index), so on a match this presumably returns the entry at i+1 converted to int rather than the index i - confirm against the implementation. */
int find_arg(int argc, char* argv[], char *arg) /* Searches argv for the flag arg. Per the original author's note: if found, the entry is removed from argv and 1 is returned; otherwise 0 is returned. */
run_detector函数
/*
 * Parses the detector command line and runs detection through
 * test_detector().
 *
 * Positional arguments, read after the option flags have been stripped:
 *   argv[3]  data config file
 *   argv[4]  network cfg file
 *   argv[5]  weights file (optional)
 *   argv[6]  input image path (optional)
 *
 * argc is the count received from main(); it is passed by value to the
 * find_*_arg helpers, so it still includes any flags they removed.
 */
void run_detector(int argc, char **argv)
{
    /*
     * Option flags. Every find_*_arg call is kept even where the result is
     * not used below: find_arg / find_int_arg remove the matched flag (and
     * its value) from argv, and the positional reads further down depend on
     * that. NOTE(review): find_char_arg / find_float_arg are assumed to
     * strip their arguments the same way - confirm.
     */
    int dont_show = find_arg(argc, argv, "-dont_show");
    int show = find_arg(argc, argv, "-show");
    int http_stream_port = find_int_arg(argc, argv, "-http_port", -1);
    char *out_filename = find_char_arg(argc, argv, "-out_filename", 0);
    char *outfile = find_char_arg(argc, argv, "-out", 0);
    char *prefix = find_char_arg(argc, argv, "-prefix", 0);
    float thresh = find_float_arg(argc, argv, "-thresh", .25); // 0.24
    float hier_thresh = find_float_arg(argc, argv, "-hier", .5);
    int cam_index = find_int_arg(argc, argv, "-c", 0);
    int frame_skip = find_int_arg(argc, argv, "-s", 0);
    int num_of_clusters = find_int_arg(argc, argv, "-num_of_clusters", 5);
    int width = find_int_arg(argc, argv, "-width", -1);
    int height = find_int_arg(argc, argv, "-height", -1);
    // extended output in test mode (output of rect bound coords)
    // and for recall mode (extended output table-like format with results for best_class fit)
    int ext_output = find_arg(argc, argv, "-ext_output");
    int save_labels = find_arg(argc, argv, "-save_labels");

    if (argc < 4) {
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    /* "-gpus a,b,c": comma-separated list of GPU indices. Without it, the
     * single global gpu_index is used. */
    char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
    int *gpus = 0;
    int gpu = 0;
    int ngpus = 0;
    if (gpu_list) {
        const char *cursor;
        int i;

        printf("%s\n", gpu_list);

        /* One entry plus one more per comma. */
        ngpus = 1;
        for (cursor = gpu_list; *cursor; ++cursor) {
            if (*cursor == ',') ++ngpus;
        }

        gpus = calloc(ngpus, sizeof *gpus);
        if (!gpus) {
            fprintf(stderr, "run_detector: out of memory while parsing -gpus\n");
            return;
        }

        cursor = gpu_list;
        for (i = 0; i < ngpus; ++i) {
            const char *comma;

            gpus[i] = atoi(cursor);
            /* The last entry has no trailing comma; stop instead of doing
             * pointer arithmetic on the NULL that strchr returns. */
            comma = strchr(cursor, ',');
            if (!comma) break;
            cursor = comma + 1;
        }
    } else {
        gpu = gpu_index;
        gpus = &gpu;
        ngpus = 1;
    }

    int clear = find_arg(argc, argv, "-clear");

    char *datacfg = argv[3];
    char *cfg = argv[4];
    char *weights = (argc > 5) ? argv[5] : 0;
    if (weights) {
        /* Drop a trailing carriage return (0x0d), if there is one. */
        size_t wlen = strlen(weights);
        if (wlen > 0 && weights[wlen - 1] == 0x0d) weights[wlen - 1] = 0;
    }
    char *filename = (argc > 6) ? argv[6] : 0;

    /*
     * The argc check above counts flags that have already been removed from
     * argv, so the positional slots can still be missing here (argv[argc] is
     * a null pointer, and removed flags presumably leave null entries at the
     * tail - TODO confirm). Refuse to run rather than pass NULL on.
     */
    if (!datacfg || !cfg) {
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
    } else {
        /* test_detector takes 12 parameters; outfile comes from "-out",
         * while letter_box and benchmark_layers are left disabled (0). */
        test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh,
                      dont_show, ext_output, save_labels, outfile, 0, 0);
    }

    /* gpus is heap-allocated only when a "-gpus" list was given. */
    if (gpu_list) free(gpus);
}
test_detector函数(detector.c 1231行)
void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh,
float hier_thresh, int dont_show, int ext_output, int save_labels, char *outfile, int letter_box, int benchmark_layers)
{
//获取data/coco.data里的内容制成列表options,key是等号前的字符串,val是等号后的(key='class', val='80)
list *options = read_data_cfg(datacfg);
char *name_list = option_find_str(options, "names", "data/names.list");
int names_size = 0;
//得到coco的类别
char **names = get_labels_custom(name_list, &names_size); //get_labels(name_list);
//用来画标签的
image **alphabet = load_alphabet();
//加载模型结构
network net = parse_network_cfg_custom(cfgfile, 1, 1); // set batch=1
//加载参数
if (weightfile) {
load_weights(&net, weightfile);
}
//对net里的conv层的w和b进行归一化操作
fuse_conv_batchnorm(net);
//计算二进制权重
calculate_binary_weights(net);
//判断类别数量是否错误
if (net.layers[net.n - 1].classes != names_size) {
printf(" Error: in the file %s number of names %d that isn't equal to classes=%d in the file %s \n",
name_list, names_size, net.layers[net.n - 1].classes, cfgfile);
if (net.layers[net.n - 1].classes > names_size) getchar();
}
srand(2222222);
double time;
char buff[256];
char *input = buff;
char *json_buf = NULL;
int json_image_id = 0;
FILE* json_file = NULL;
if (outfile) {
json_file = fopen(outfile, "wb");
char *tmp = "[\n";
fwrite(tmp, sizeof(char), strlen(tmp), json_file);
}
int j;
float nms = .45; // 0.4F
//读取图片并进行识别
while (1) {
if (filename) {
strncpy(input, filename, 256);
if (strlen(input) > 0)
if (input[strlen(input) - 1] == 0x0d) input[strlen(input) - 1] = 0;
}
else {
printf("Enter Image Path: ");
fflush(stdout);
input = fgets(input, 256, stdin);
if (!input) break;
strtok(input, "\n");
}
//加载图片并且resize成(416, 416, 3)
image im = load_image(input, 0, 0, net.c);
image sized = resize_image(im, net.w, net.h);
int letterbox = 0;
layer l = net.layers[net.n - 1];
float *X = sized.data;
double time = get_time_point();
//把数据输入网络
network_predict(net, X);
printf("%s: Predicted in %lf milli-seconds.\n", input, ((double)get_time_point() - time) / 1000);
int nboxes = 0;
//获得bbbox,可以利用vs的查看定义看det里包含了什么
detection *dets = get_network_boxes(&net, im.w, im.h, thresh, hier_thresh, 0, 1, &nboxes, letterbox);
//非极大值抑制
if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
//画框
draw_detections_v3(im, dets, nboxes, thresh, names, alphabet, l.classes, ext_output);
save_image(im, "predictions");
if (!dont_show) {
show_image(im, "predictions");
}
}