本文从detector.c文件中train_detector函数调用load_data的位置开始,解读darknet数据加载部分的代码。
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
{
list *options = read_data_cfg(datacfg);
char *train_images = option_find_str(options, "train", "data/train.list");
//store weights?
char *backup_directory = option_find_str(options, "backup", "/backup/");
srand(time(0));
//from /a/b/yolov2.cfg extract yolov2
char *base = basecfg(cfgfile); //network config
printf("%s\n", base);
float avg_loss = -1;
network **nets = calloc(ngpus, sizeof(network));
srand(time(0));
int seed = rand();
int i;
for(i = 0; i < ngpus; ++i){
srand(seed);
#ifdef GPU
cuda_set_device(gpus[i]);
#endif
//create network for every GPU
nets[i] = load_network(cfgfile, weightfile, clear);
nets[i]->learning_rate *= ngpus;
}
srand(time(0));
network *net = nets[0];
//subdivisions,why not divide?
int imgs = net->batch * net->subdivisions * ngpus;
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);
data train, buffer;
//the last layer e.g. [region] for yolov2
layer l = net->layers[net->n - 1];
int classes = l.classes;
float jitter = l.jitter;
list *plist = get_paths(train_images);
//int N = plist->size;
char **paths = (char **)list_to_array(plist);
load_args args = get_base_args(net);
args.coords = l.coords;
args.paths = paths;
args.n = imgs; //一次加载的数量
args.m = plist->size; //总的图片数量
args.classes = classes;
args.jitter = jitter;
args.num_boxes = l.max_boxes;
args.d = &buffer;
args.type = DETECTION_DATA;
//args.type = INSTANCE_DATA;
args.threads = 64;
/*n张图片以及图片上的truth box会被加载到buffer.X,buffer.y里面去*/
pthread_t load_thread = load_data(args);
....
}
在传入load_data函数的args结构中,有几个成员需要关注:args.n表示这一次加载的图像数量(即batch * subdivisions * ngpus),args.m表示训练集中图像的总数,args.num_boxes表示一张图像中允许的最大检测框数量。
/*
 * Starts a background "manager" thread running load_threads() and returns its
 * id immediately, without waiting for any data to be loaded.
 *
 * args - copied to the heap because the thread outlives this stack frame;
 *        load_threads() takes ownership of that copy and frees it.
 * Returns the id of the started thread (the caller presumably pthread_join()s
 * it before reading *args.d).
 */
pthread_t load_data(load_args args)
{
    struct load_args *heap_args = calloc(1, sizeof(struct load_args));
    *heap_args = args;

    pthread_t manager;
    int rc = pthread_create(&manager, 0, load_threads, heap_args);
    if (rc != 0) {
        error("Thread creation failed");
    }
    return manager;
}
load_data函数会创建一个执行load_threads的线程,并且创建后立即返回,不会等待数据加载完成。从名字上理解,load_threads是一个"用来创建加载线程的线程":数据的加载并不由它直接负责,它更像是数据加载线程的管理者。
/*
 * Manager thread body started by load_data(). Splits the request for args.n
 * samples across args.threads worker threads, waits for all of them, and
 * merges their partial results into *args.d.
 *
 * ptr - heap-allocated load_args; this function takes ownership and frees it.
 * Returns 0 (the thread exit value carries no information).
 */
void *load_threads(void *ptr)
{
    load_args args = *(load_args *)ptr;
    free(ptr);                       // safe: everything needed was copied above
    if (args.threads == 0) args.threads = 1;

    const int nthreads = args.threads;
    const int total = args.n;
    data *out = args.d;

    data *parts = calloc(nthreads, sizeof(data));
    pthread_t *workers = calloc(nthreads, sizeof(pthread_t));

    for (int t = 0; t < nthreads; ++t) {
        args.d = &parts[t];
        // Worker t loads (t+1)*total/nthreads - t*total/nthreads samples.
        // Unlike a plain total/nthreads, these shares always add up to exactly
        // total: the remainder is spread one sample at a time over the workers
        // (e.g. total=10, nthreads=4 -> 2, 3, 2, 3).
        args.n = (t + 1) * total / nthreads - t * total / nthreads;
        workers[t] = load_data_in_thread(args);
    }

    for (int t = 0; t < nthreads; ++t) {
        pthread_join(workers[t], 0);
    }

    *out = concat_datas(parts, nthreads);
    out->shallow = 0;

    for (int t = 0; t < nthreads; ++t) {
        // NOTE(review): concat_datas presumably shares row buffers with the
        // parts, so each part is marked shallow before free_data() to avoid
        // freeing rows now owned by *out — confirm in data.c.
        parts[t].shallow = 1;
        free_data(parts[t]);
    }
    free(parts);
    free(workers);
    return 0;
}
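之所以说load_threads是数据加载线程的管理者,是因为真正负责加载数据的线程是在load_data_in_thread中创建的。load_threads只保存这些子线程的线程ID,并通过pthread_join等待它们完成数据加载。创建多少个子线程由传入的args.threads成员决定。一次需要加载args.n张图像,由args.threads个线程分担,因此每个线程大约加载args.n/args.threads张。代码中使用(i+1)*total/args.threads - i*total/args.threads,而不是直接使用total/args.threads。这样在args.n不能被args.threads整除时,余数会被分摊到各个线程上,保证所有线程加载的数量之和恰好等于args.n。例如total=10、threads=4时,各线程分别加载2、3、2、3张;若直接用10/4=2,则总共只加载8张。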
/*
 * Starts one worker thread running load_thread() on a heap copy of args and
 * returns its id without waiting for it.
 *
 * args - copied to the heap; load_thread() takes ownership of the copy and
 *        frees it.
 * Returns the id of the started worker thread.
 */
pthread_t load_data_in_thread(load_args args)
{
    struct load_args *worker_args = calloc(1, sizeof(struct load_args));
    *worker_args = args;

    pthread_t worker;
    if (pthread_create(&worker, 0, load_thread, worker_args) != 0) {
        error("Thread creation failed");
    }
    return worker;
}
void *load_thread(void *ptr)
{
//printf("Loading data: %d\n", rand());
load_args a = *(struct load_args*)ptr;
if(a.exposure == 0) a.exposure = 1;
if(a.saturation == 0) a.saturation = 1;
if(a.aspect == 0) a.aspect = 1;
if (a.type == OLD_CLASSIFICATION_DATA){
*a.d = load_data_old(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
}
//...省略...
} else if (a.type == DETECTION_DATA){
//detection data
*a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
}
//...省略...
}
free(ptr);
return 0;
}
load_data_in_thread函数会创建真正负责加载数据的子线程(执行load_thread),并返回其线程ID(pthread_t)。load_thread会根据要加载的数据类型调用相应的函数。这里只考虑DETECTION_DATA,也就是检测数据的情形,因此会进一步调用load_data_detection函数。
/*
 * Builds one detection batch of n augmented samples.
 *
 * n        - number of images to load in this call (this worker's share)
 * paths    - all training image paths
 * m        - number of entries in paths
 * w, h     - network input size; every sample is placed into a w x h canvas
 * boxes    - maximum number of truth boxes stored per image
 * classes  - forwarded to fill_truth_detection()
 * jitter   - max relative perturbation of width/height for aspect jitter
 * hue, saturation, exposure - forwarded to random_distort_image()
 *
 * Returns a data whose X has n rows of h*w*3 floats (row i owns the pixel
 * buffer of the i-th augmented image) and whose y has n rows of 5*boxes
 * floats, one [x, y, w, h, class_id] group per box (see fill_truth_detection).
 */
data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, float jitter, float hue, float saturation, float exposure)
{
    // n paths picked from the m available ones.
    char **random_paths = get_random_paths(paths, n, m);
    int i;
    data d = {0};
    d.shallow = 0;

    d.X.rows = n;
    d.X.vals = calloc(d.X.rows, sizeof(float*));
    // NOTE(review): assumes 3 channels, while sized below uses orig.c.
    d.X.cols = h*w*3;
    // Per image: boxes groups of 5 floats = (x, y, w, h, class id).
    d.y = make_matrix(n, 5*boxes);

    for(i = 0; i < n; ++i){
        image orig = load_image_color(random_paths[i], 0, 0);
        // Target-size canvas pre-filled with 0.5; any area the placed image
        // does not cover keeps that value.
        image sized = make_image(w, h, orig.c);
        fill_image(sized, .5);

        // Aspect-ratio jitter: perturb width and height independently by up to
        // +-jitter of their original size.
        float dw = jitter * orig.w;
        float dh = jitter * orig.h;
        // Draw the two perturbations in separate statements. The evaluation
        // order of the operands of '/' is unspecified in C, so placing both
        // rand_uniform() calls in one expression makes the draw order (and the
        // result for a given seed) compiler-dependent.
        float jw = rand_uniform(-dw, dw);
        float jh = rand_uniform(-dh, dh);
        float new_ar = (orig.w + jw) / (orig.h + jh);

        // Random scale: the dominant side becomes scale * the matching target
        // side, and the other side follows from new_ar, so the jittered aspect
        // ratio is preserved.
        float scale = rand_uniform(.25, 2);
        float nw, nh;
        if(new_ar < 1){
            nh = scale * h;
            nw = nh * new_ar;
        } else {
            nw = scale * w;
            nh = nw / new_ar;
        }

        // Random placement (translation) inside the canvas.
        // NOTE(review): when nw > w (or nh > h) the range is reversed; this
        // relies on rand_uniform() accepting min > max.
        float dx = rand_uniform(0, w - nw);
        float dy = rand_uniform(0, h - nh);
        place_image(orig, nw, nh, dx, dy, sized);
        random_distort_image(sized, hue, saturation, exposure);

        // Random horizontal flip with probability 1/2.
        int flip = rand()%2;
        if(flip) flip_image(sized);

        // Row i of X takes ownership of the augmented pixels.
        d.X.vals[i] = sized.data;

        // Apply the same shift (-dx/w, -dy/h), scale (nw/w, nh/h) and flip,
        // expressed as fractions of the canvas, to the truth boxes.
        fill_truth_detection(random_paths[i], boxes, d.y.vals[i], classes, flip, -dx/w, -dy/h, nw/w, nh/h);
        free_image(orig);
    }
    free(random_paths);
    return d;
}
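这里的核心是这个n次的for循环,每次循环加载一张图像。首先由load_image_color将图像文件加载到image结构中。因为网络要求的输入尺寸是(w,h),所以紧接着通过make_image生成一张(w,h,orig.c)的空白图像,并用0.5填充。随后对原始图像做一系列随机变换,再放入这张空白图像中。这些变换本质上是数据增强(data augmentation):
- 按jitter随机扰动宽高比;
- 在[0.25, 2]范围内随机缩放;
- 随机选择放置位置(dx, dy);
- 通过random_distort_image随机扰动图像(从参数看是色调、饱和度和曝光);
- 以50%的概率水平翻转。

由于图像中物体的位置和尺寸随之改变,标注框也必须做相同的变换。这就是调用fill_truth_detection时要传入flip、-dx/w、-dy/h、nw/w、nh/h的原因。整个函数中还有一个比较需要注意的点,就是fill_truth_detection这个函数。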
/*
 * Reads the label file belonging to the image at path, applies the same
 * geometric augmentation as the image (shift dx/dy, scale sx/sy, flip), and
 * writes up to num_boxes boxes into truth as consecutive [x, y, w, h, id]
 * groups.
 *
 * The label path is derived by textual substitution: "images", "JPEGImages"
 * and "raw" become "labels", and the image extensions become ".txt".
 *
 * truth   - must hold at least 5*num_boxes floats; entries after the last
 *           written box are left untouched.
 * classes - unused here.
 */
void fill_truth_detection(char *path, int num_boxes, float *truth, int classes, int flip, float dx, float dy, float sx, float sy)
{
    char labelpath[4096];
    // Map the image path onto the dataset's label-file layout.
    find_replace(path, "images", "labels", labelpath);
    find_replace(labelpath, "JPEGImages", "labels", labelpath);
    find_replace(labelpath, "raw", "labels", labelpath);
    find_replace(labelpath, ".jpg", ".txt", labelpath);
    find_replace(labelpath, ".png", ".txt", labelpath);
    find_replace(labelpath, ".JPG", ".txt", labelpath);
    find_replace(labelpath, ".JPEG", ".txt", labelpath);

    int count = 0;
    box_label *boxes = read_boxes(labelpath, &count);
    // Shuffle; together with the num_boxes cap below, an image with too many
    // boxes contributes a random subset rather than always the first ones.
    randomize_boxes(boxes, count);
    // The image was shifted/scaled/flipped, so transform the boxes identically.
    correct_boxes(boxes, count, dx, dy, sx, sy, flip);

    // Filter degenerate boxes BEFORE applying the num_boxes cap, so a skipped
    // box frees its slot for the next valid one. (Capping count first, as the
    // original did, could drop valid boxes while slots remained unused.)
    int kept = 0;
    int i;
    for (i = 0; i < count && kept < num_boxes; ++i) {
        float w = boxes[i].w;
        float h = boxes[i].h;
        // Skip boxes with (near-)zero width or height, e.g. ones that
        // presumably ended up outside the frame after correct_boxes().
        if (w < .001 || h < .001) continue;

        float *dst = truth + kept*5;
        dst[0] = boxes[i].x;
        dst[1] = boxes[i].y;
        dst[2] = w;
        dst[3] = h;
        dst[4] = boxes[i].id;
        ++kept;
    }
    free(boxes);
}