yolo v3 源码阅读(2):数据格式与加载

训练时通过 load_data_in_thread 方法另起一个线程,把数据加载到 args.d 指针所指向的缓冲区中:

//data.c
/*
 * Start a worker thread that runs load_thread on a private copy of `args`.
 *
 * `args` is copied to the heap so the worker does not depend on the caller's
 * stack frame; load_thread frees that copy when it finishes.
 *
 * Returns the handle of the created thread. Allocation or thread-creation
 * failure is reported through error().
 */
pthread_t load_data_in_thread(load_args args)
{
    pthread_t thread;
    struct load_args *ptr = calloc(1, sizeof *ptr);

    /* Without this check an out-of-memory condition would be a NULL dereference below. */
    if(!ptr) error("calloc failed");
    *ptr = args;

    if(pthread_create(&thread, NULL, load_thread, ptr)){
        /* The worker never started, so nobody else will release the copy. */
        free(ptr);
        error("Thread creation failed");
    }
    return thread;
}

执行load_thread方法,开启线程加载数据

//data.c
/*
 * Thread entry point used by load_data_in_thread.
 *
 * `ptr` points to a heap-allocated load_args. It is copied into a local,
 * the loader selected by `type` is called, and the result is stored through
 * the output pointers carried in the arguments (`d`, or `im` and `resized`
 * for the single-image types). `ptr` is freed before returning.
 *
 * Always returns 0. An unrecognised `type` loads nothing.
 */
void *load_thread(void *ptr)
{
    load_args a = *(struct load_args*)ptr;

    /* A value of 0 in these three fields is replaced by 1 before dispatch. */
    if(a.exposure == 0) a.exposure = 1;
    if(a.saturation == 0) a.saturation = 1;
    if(a.aspect == 0) a.aspect = 1;

    switch (a.type) {
    case OLD_CLASSIFICATION_DATA:
        *a.d = load_data_old(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
        break;
    case REGRESSION_DATA:
        *a.d = load_data_regression(a.paths, a.n, a.m, a.classes, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
        break;
    case CLASSIFICATION_DATA:
        *a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.hierarchy, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure, a.center);
        break;
    case SUPER_DATA:
        *a.d = load_data_super(a.paths, a.n, a.m, a.w, a.h, a.scale);
        break;
    case WRITING_DATA:
        *a.d = load_data_writing(a.paths, a.n, a.m, a.w, a.h, a.out_w, a.out_h);
        break;
    case ISEG_DATA:
        *a.d = load_data_iseg(a.n, a.paths, a.m, a.w, a.h, a.classes, a.num_boxes, a.scale, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
        break;
    case INSTANCE_DATA:
        *a.d = load_data_mask(a.n, a.paths, a.m, a.w, a.h, a.classes, a.num_boxes, a.coords, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
        break;
    case SEGMENTATION_DATA:
        *a.d = load_data_seg(a.n, a.paths, a.m, a.w, a.h, a.classes, a.min, a.max, a.angle, a.aspect, a.hue, a.saturation, a.exposure, a.scale);
        break;
    case REGION_DATA:
        *a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
        break;
    case DETECTION_DATA:
        *a.d = load_data_detection(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter, a.hue, a.saturation, a.exposure);
        break;
    case SWAG_DATA:
        *a.d = load_data_swag(a.paths, a.n, a.classes, a.jitter);
        break;
    case COMPARE_DATA:
        *a.d = load_data_compare(a.n, a.paths, a.m, a.classes, a.w, a.h);
        break;
    case IMAGE_DATA:
        /* Single image: keep the original in *im and a plainly resized copy in *resized. */
        *(a.im) = load_image_color(a.path, 0, 0);
        *(a.resized) = resize_image(*(a.im), a.w, a.h);
        break;
    case LETTERBOX_DATA:
        /* Single image: same as IMAGE_DATA but the copy is produced by letterbox_image. */
        *(a.im) = load_image_color(a.path, 0, 0);
        *(a.resized) = letterbox_image(*(a.im), a.w, a.h);
        break;
    case TAG_DATA:
        *a.d = load_data_tag(a.paths, a.n, a.m, a.classes, a.min, a.max, a.size, a.angle, a.aspect, a.hue, a.saturation, a.exposure);
        break;
    default:
        break;
    }

    /* The argument copy was allocated by load_data_in_thread for this thread only. */
    free(ptr);
    return 0;
}

以上代码根据 args 的 type 属性,决定调用哪一个 load_data_* 方法来加载数据;

通过加断点,我们发现 运行yolo train 的时候,调用的是 load_data_detection

//data.c
/*
 * Load n randomly chosen images together with their box labels for
 * detection training, applying random placement, colour distortion and
 * random flipping to each image.
 *
 * Returns a `data` whose X has one row per image (h*w*3 floats) and whose y
 * has one row per image with 5*boxes floats.
 */
data load_data_detection(int n, char **paths, int m, int w, int h, int boxes, int classes, float jitter, float hue, float saturation, float exposure)
{
    /*
     * n      number of images loaded by this call (rows of X and y)
     * paths  image paths from which the random selection is drawn
     * m      passed to get_random_paths along with paths; presumably the
     *        number of entries in paths -- TODO confirm
     * w,h    width and height of every output image
     * boxes  number of 5-float label slots reserved per image in y
     * classes, jitter, hue, saturation, exposure
     *        forwarded to the labelling / augmentation helpers below
     */

    // Pick n image paths at random.
    char **random_paths = get_random_paths(paths, n, m);
    int i;
    data d = {0};
    d.shallow = 0;

    // X holds n rows, one per image.
    d.X.rows = n;
    d.X.vals = calloc(d.X.rows, sizeof(float*));
    // Each row is one image of w * h * 3 floats.
    d.X.cols = h*w*3;

    // y holds n rows of 5 * boxes floats (x, y, w, h, id for each box slot).
    d.y = make_matrix(n, 5*boxes);
    for(i = 0; i < n; ++i){
        // Load the image at its stored size (0, 0 means no resize in load_image).
        image orig = load_image_color(random_paths[i], 0, 0);


        // Canvas of the output size that will receive the placed image.
        image sized = make_image(w, h, orig.c);
        // Pre-fill the canvas with the constant .5 (presumably every value, i.e. mid-gray -- confirm in fill_image).
        fill_image(sized, .5);

        /*
         * Jitter: perturb the source width and height by up to +/- jitter
         * times their size to get a randomised aspect ratio. For example,
         * with jitter = 0.2 each dimension varies by at most one fifth.
         */

        float dw = jitter * orig.w;
        float dh = jitter * orig.h;

        // Randomised aspect ratio (width / height) of the placed image.
        float new_ar = (orig.w + rand_uniform(-dw, dw)) / (orig.h + rand_uniform(-dh, dh));
        // scale is fixed at 1; the random variant is disabled:
        //float scale = rand_uniform(.25, 2);
        float scale = 1;

        float nw, nh;

        // new_ar < 1: taller than wide, so the height spans h and the width follows;
        // otherwise the width spans w and the height follows.
        if(new_ar < 1){
            nh = scale * h;
            nw = nh * new_ar;
        } else {
            nw = scale * w;
            nh = nw / new_ar;
        }

        // Random offset of the placed image inside the canvas.
        float dx = rand_uniform(0, w - nw);
        float dy = rand_uniform(0, h - nh);
        // Draw orig at size nw x nh and offset (dx, dy) into sized (presumably resampling -- see place_image).
        place_image(orig, nw, nh, dx, dy, sized);

        // Random hue / saturation / exposure adjustment.
        random_distort_image(sized, hue, saturation, exposure);

        // Flip the image with probability 1/2.
        int flip = rand()%2;
        if(flip) flip_image(sized);
        // Row i of X points at sized's pixel buffer; sized is not freed here.
        d.X.vals[i] = sized.data;



        // Labels must follow the image: pass the same flip and the placement
        // expressed relative to the output size.

        fill_truth_detection(random_paths[i], boxes, d.y.vals[i], classes, flip, -dx/w, -dy/h, nw/w, nh/h);

        free_image(orig);
    }
    free(random_paths);
    return d;
}


/*
 * Read the box labels belonging to the image at `path` and write them into
 * `truth`, adjusted for the transform that was applied to the image.
 *
 * path       image path; the label path is derived from it by substring
 *            replacement (directory names -> "labels", image extension -> ".txt")
 * num_boxes  maximum number of boxes written
 * truth      output buffer, 5 floats per box: x, y, w, h, id
 * classes    not used in this function
 * flip, dx, dy, sx, sy
 *            forwarded to correct_boxes to map the labels onto the
 *            transformed image
 *
 * Boxes whose width or height is below .001 after correction are skipped,
 * and the remaining boxes are packed contiguously from the start of `truth`.
 *
 * Example recorded by the original author (not verified here):
 *   before: 2 0.36666666666666664 0.42824074074074076 0.17083333333333334 0.25277777777777777
 *   after:  2 0.3666666675 0.767904878 0.170833349 0.09741116
 */
void fill_truth_detection(char *path, int num_boxes, float *truth, int classes, int flip, float dx, float dy, float sx, float sy)
{
    /* Replacements applied in order to the label path after the first one. */
    char *swaps[][2] = {
        {"JPEGImages", "labels"},
        {"raw",        "labels"},
        {".jpg",       ".txt"},
        {".png",       ".txt"},
        {".JPG",       ".txt"},
        {".JPEG",      ".txt"},
    };
    int n_swaps = sizeof(swaps) / sizeof(swaps[0]);
    int k;

    char labelpath[4096];
    find_replace(path, "images", "labels", labelpath);
    for (k = 0; k < n_swaps; ++k) {
        find_replace(labelpath, swaps[k][0], swaps[k][1], labelpath);
    }

    /* Load the labels, shuffle them, then map them onto the transformed image. */
    int total = 0;
    box_label *labels = read_boxes(labelpath, &total);
    randomize_boxes(labels, total);
    correct_boxes(labels, total, dx, dy, sx, sy, flip);

    /* Never write more boxes than the caller reserved room for. */
    if (total > num_boxes) total = num_boxes;

    int kept = 0;   /* number of boxes written so far */
    for (k = 0; k < total; ++k) {
        float bx = labels[k].x;
        float by = labels[k].y;
        float bw = labels[k].w;
        float bh = labels[k].h;
        int cls = labels[k].id;

        /* Skip boxes that are degenerate after correction. */
        if (bw < .001 || bh < .001) continue;

        float *slot = truth + kept*5;
        slot[0] = bx;
        slot[1] = by;
        slot[2] = bw;
        slot[3] = bh;
        slot[4] = cls;
        ++kept;
    }

    free(labels);
}

到此,我们发现 yolo 的 load_data :

读取图片,按随机抖动后的宽高比缩放,并放置到网络输入尺寸 (w×h) 的画布上;再随机调整色调、饱和度、曝光度,并随机翻转,以此增广数据;同时把 label 中的 box 按同样的变换校正到处理后的图像坐标中,最后返回。

附录:

加载彩色图片:

/*
 * Load `filename` as a 3-channel image by delegating to load_image.
 * `w` and `h` are forwarded unchanged.
 */
image load_image_color(char *filename, int w, int h)
{
    const int channels = 3;
    return load_image(filename, w, h, channels);
}
/*
 * Load an image from `filename` with `c` channels.
 *
 * filename  path of the image file
 * w, h      requested width and height; when both are non-zero and differ
 *           from the loaded size, the image is resized to w x h
 * c         number of colour channels requested from the decoder
 *
 * Returns the loaded (and possibly resized) image.
 */
image load_image(char *filename, int w, int h, int c)
{
#ifdef OPENCV
    image loaded = load_image_cv(filename, c);
#else
    image loaded = load_image_stb(filename, c);
#endif

    /* No target size was given, or the image already has it: return as loaded. */
    int size_requested = (h != 0) && (w != 0);
    if(!size_requested || (loaded.h == h && loaded.w == w)){
        return loaded;
    }

    /* Replace the decoded image with a resized copy and release the original. */
    image scaled = resize_image(loaded, w, h);
    free_image(loaded);
    return scaled;
}

darknet 所用数据 结构体:

//matrix.h
// Table of floats stored as an array of `rows` row pointers, each row holding `cols` values.
// As used by load_data_detection: X has one row per loaded image (cols = h*w*3),
// y has one row per image with 5*boxes label values.
// Original author's note: rows is the number of samples loaded at once
// (batch * net.subdivisions) -- not verifiable from this file, confirm against the caller.
typedef struct matrix{
    int  rows, cols;
    float **vals;
} matrix;

// Image buffer: dimensions plus a pointer to float pixel data.
typedef struct {
    int w;        // width
    int h;        // height
    int c;        // number of colour channels
    float *data;  // pixel values; presumably w*h*c floats -- TODO confirm layout
} image;


// One loaded batch: inputs and their labels.
typedef struct{
// image width and height
    int w, h;
// X holds the image contents, y holds the labels
    matrix X;
    matrix y;
    // set to 0 by load_data_detection; presumably marks whether the rows are
    // borrowed rather than owned -- TODO confirm
    int shallow;
    // not touched by the loaders shown here
    int *num_boxes;
    box **boxes;
} data;

load_args 结构体(load_data_in_thread 传给 load_thread 的参数):
// Arguments handed to load_data_in_thread / load_thread.
// The field notes below describe how load_thread reads each field.
typedef struct load_args{
    int threads;            // not read by load_thread
    char **paths;           // image path list passed to the load_data_* loaders
    char *path;             // single image path (IMAGE_DATA, LETTERBOX_DATA)
    int n;                  // passed to the loaders; number of images loaded in load_data_detection
    int m;                  // passed to the loaders together with paths
    char **labels;          // passed to load_data_old and load_data_augment
    int h;                  // target height
    int w;                  // target width
    int out_w;              // used only by load_data_writing
    int out_h;              // used only by load_data_writing
    int nh;                 // not read by load_thread
    int nw;                 // not read by load_thread
    int num_boxes;          // box count passed to the iseg, mask, region and detection loaders
    int min, max, size;     // size parameters forwarded to several loaders
    int classes;            // class count forwarded to most loaders
    int background;         // not read by load_thread
    int scale;              // passed to the super, iseg and seg loaders
    int center;             // passed to load_data_augment
    int coords;             // passed to load_data_mask
    float jitter;           // passed to the region, detection and swag loaders
    float angle;            // augmentation parameter
    float aspect;           // augmentation parameter; 0 is replaced by 1 in load_thread
    float saturation;       // augmentation parameter; 0 is replaced by 1 in load_thread
    float exposure;         // augmentation parameter; 0 is replaced by 1 in load_thread
    float hue;              // augmentation parameter
    data *d;                // output: receives the loaded batch
    image *im;              // output: original image (IMAGE_DATA, LETTERBOX_DATA)
    image *resized;         // output: resized / letterboxed image (IMAGE_DATA, LETTERBOX_DATA)
    data_type type;         // selects which loader load_thread calls
    tree *hierarchy;        // passed to load_data_augment
} load_args;




你可能感兴趣的:(yolo v3 源码阅读(2):数据格式与加载)