DPM(Defomable Parts Model) 源码分析-训练(三)

申明:本文非笔者原创,原文转载自:http://blog.csdn.net/ttransposition/article/details/12954631


DPM(Defomable Parts Model)原理

首先调用格式:

example:
pascal('person', 2);   % train and evaluate a 2 component person model

pascal_train.m

[cpp]  view plain copy
  1. function model = pascal_train(cls, n) % n=2  
  2.   
  3. % model = pascal_train(cls)  
  4. % Train a model using the PASCAL dataset.  
  5.   
  6. globals;   
  7. %----------读取正负样本-----------------------  
  8. % pos.im,neg.im存储了图像路径,pos.x1..pos.y2为box,负样本无box  
  9. [pos, neg] = pascal_data(cls);  
  10.   
  11. % 按照长宽比,分成等量的两部分? 即将 component label  固定,phase2时,该值为latent variable。  spos为索引  
  12. spos = split(pos, n);  
  13.   
  14. % -----------phase 1 : train root filters using warped positives & random negatives-----------  
  15. try  
  16.   load([cachedir cls '_random']);  
  17. catch  
  18. % -----------------------------phas 1--------------------------------  
  19. % 初始化 rootfilters  
  20.   for i=1:n  
  21.     models{i} = initmodel(spos{i});  
  22.     %---------train-------------  
  23.     % model.rootfilters{i}.w  
  24.     % model.offsets{i}.w  
  25.     models{i} = train(cls, models{i}, spos{i}, neg, 1, 1, 1, 1, 2^28);  
  26.   
  27.   end  
  28.   save([cachedir cls '_random'], 'models');  
  29. end  
  30.   
  31. % -----------------phase2-------------------------------------------  
  32. % :merge models and train using latent detections & hard negatives  
  33. try   
  34.   load([cachedir cls '_hard']);  
  35. catch  
  36.   model = mergemodels(models);  
  37.   model = train(cls, model, pos, neg(1:200), 0, 0, 2, 2, 2^28, true, 0.7);  
  38.   save([cachedir cls '_hard'], 'model');  
  39. end  
  40. %----------------phase 3----------------------------------------------  
  41. % add parts and update models using latent detections & hard negatives.  
  42. try   
  43.   load([cachedir cls '_parts']);  
  44. catch  
  45.   for i=1:n  
  46.     model = addparts(model, i, 6);  
  47.   end   
  48.   % use more data mining iterations in the beginning  
  49.   model = train(cls, model, pos, neg(1:200), 0, 0, 1, 4, 2^30, true, 0.7);  
  50.   model = train(cls, model, pos, neg(1:200), 0, 0, 6, 2, 2^30, true, 0.7, true);  
  51.   save([cachedir cls '_parts'], 'model');  
  52. end  
  53.   
  54. % update models using full set of negatives.  
  55. try   
  56.   load([cachedir cls '_mine']);  
  57. catch  
  58.   model = train(cls, model, pos, neg, 0, 0, 1, 3, 2^30, true, 0.7, true, ...  
  59.                 0.003*model.numcomponents, 2);  
  60.   save([cachedir cls '_mine'], 'model');  
  61. end  
  62.   
  63. % train bounding box prediction  
  64. try  
  65.   load([cachedir cls '_final']);  
  66. catch  
  67.  % 论文中说用最小二乘,怎么直接相除了,都不考虑矩阵的奇异性  
  68.   model = trainbox(cls, model, pos, 0.7);  
  69.   save([cachedir cls '_final'], 'model');  
  70. end  

initmodel.m

[cpp]  view plain copy
  1. function model = initmodel(pos, sbin, size)  
  2.   
  3. % model = initmodel(pos, sbin, size)  
  4. % Initialize model structure.  
  5. %  
  6. % If not supplied the dimensions of the model template are computed  
  7. % from statistics in the postive examples.  
  8. %   
  9. % This should be documented! :-)  
  10. % model.sbin         8  
  11. % model.interval     10  
  12. % model.numblocks     phase 1 :单独训练rootfilter时为2,offset,rootfilter;phase 2,为 4   
  13. % model.numcomponents  1  
  14. % model.blocksizes     (1)=1,(2)= root.h*root.w/2*31  
  15. % model.regmult        0,1  
  16. % model.learnmult      20,1  
  17. % model.maxsize        root 的size   
  18. % model.minsize  
  19. % model.rootfilters{i}  
  20. %   .size               以sbin为单位,尺寸为综合各样本的h/w,area计算出来的  
  21. %   .w  
  22. %   .blocklabel        blocklabel是为编号,offset(2),rootfilter(2),partfilter(12 or less),def (12 same as part)虽然意义不同但是放在一起统一编号  
  23. % model.partfilters{i}  
  24. %   .w  
  25. %   .blocklabel  
  26. % model.defs{i}  
  27. %   .anchor  
  28. %   .w  
  29. %   .blocklabel  
  30. % model.offsets{i}  
  31. %   .w               0  
  32. %   .blocklabel       1  
  33. % model.components{i}  
  34. %   .rootindex    1  
  35. %   .parts{j}  
  36. %     .partindex  
  37. %     .defindex  
  38. %   .offsetindex    1  
  39. %   .dim             2 + model.blocksizes(1) + model.blocksizes(2)  
  40. %   .numblocks       2  
  41.   
  42. % pick mode of aspect ratios  
  43. h = [pos(:).y2]' - [pos(:).y1]' + 1;  
  44. w = [pos(:).x2]' - [pos(:).x1]' + 1;  
  45. xx = -2:.02:2;  
  46. filter = exp(-[-100:100].^2/400); % e^-25,e^25  
  47. aspects = hist(log(h./w), xx); %  
  48. aspects = convn(aspects, filter, 'same');  
  49. [peak, I] = max(aspects);  
  50. aspect = exp(xx(I)); %滤波后最大的h/w,作为最典型的h/w  
  51.   
  52. % pick 20 percentile area  
  53. areas = sort(h.*w);  
  54. area = areas(floor(length(areas) * 0.2)); % 比它大的,可以缩放,比该尺寸小的呢?  
  55. area = max(min(area, 5000), 3000); %限制在 3000-5000  
  56.   
  57. % pick dimensions  
  58. w = sqrt(area/aspect);  
  59. h = w*aspect;  
  60.   
  61. % size of HOG features  
  62. if nargin < 4  
  63.   model.sbin = 8;  
  64. else  
  65.   model.sbin = sbin;  
  66. end  
  67.   
  68. % size of root filter  
  69. if nargin < 5  
  70.   model.rootfilters{1}.size = [round(h/model.sbin) round(w/model.sbin)];  
  71. else  
  72.   model.rootfilters{1}.size = size;  
  73. end  
  74.   
  75. % set up offset   
  76. model.offsets{1}.w = 0;  
  77. model.offsets{1}.blocklabel = 1;  
  78. model.blocksizes(1) = 1;  
  79. model.regmult(1) = 0;  
  80. model.learnmult(1) = 20;  
  81. model.lowerbounds{1} = -100;  
  82.   
  83. % set up root filter  
  84. model.rootfilters{1}.w = zeros([model.rootfilters{1}.size 31]);  
  85. height = model.rootfilters{1}.size(1);  
  86. % root filter is symmetricf  
  87. width = ceil(model.rootfilters{1}.size(2)/2);  % ??? /2  
  88. model.rootfilters{1}.blocklabel = 2;  
  89. model.blocksizes(2) = width * height * 31;  
  90. model.regmult(2) = 1;  
  91. model.learnmult(2) = 1;  
  92. model.lowerbounds{2} = -100*ones(model.blocksizes(2),1);  
  93.   
  94. % set up one component model  
  95. model.components{1}.rootindex = 1;  
  96. model.components{1}.offsetindex = 1;  
  97. model.components{1}.parts = {};  
  98. model.components{1}.dim = 2 + model.blocksizes(1) + model.blocksizes(2);  
  99. model.components{1}.numblocks = 2;  
  100.   
  101. % initialize the rest of the model structure  
  102. model.interval = 10;  
  103. model.numcomponents = 1;  
  104. model.numblocks = 2;  
  105. model.partfilters = {};  
  106. model.defs = {};  
  107. model.maxsize = model.rootfilters{1}.size;  
  108. model.minsize = model.rootfilters{1}.size;  


 

learn.cc

[cpp]  view plain copy
  1. #include <stdio.h>  
  2. #include <stdlib.h>  
  3. #include <string.h>  
  4. #include <math.h>  
  5. #include <sys/time.h>  
  6. #include <errno.h>  
  7.   
  8. /* 
  9.  * Optimize LSVM objective function via gradient descent. 
  10.  * 
  11.  * We use an adaptive cache mechanism.  After a negative example 
  12.  * scores beyond the margin multiple times it is removed from the 
  13.  * training set for a fixed number of iterations. 
  14.  */  
  15.   
  16. // Data File Format  
  17. // EXAMPLE*  
  18. //   
  19. // EXAMPLE:  
  20. //  long label          ints  
  21. //  blocks              int  
  22. //  dim                 int  
  23. //  DATA{blocks}  
  24. //  
  25. // DATA:  
  26. //  block label         float  
  27. //  block data          floats  
  28. //  
  29. // Internal Binary Format  
  30. //  len           int (byte length of EXAMPLE)  
  31. //  EXAMPLE       <see above>  
  32. //  unique flag   byte  
  33.   
  34. // number of iterations  
  35. #define ITER 5000000  
  36.   
  37. // small cache parameters  
  38. #define INCACHE 3  
  39. #define WAIT 10  
  40.   
  41. // error checking  
  42. #define check(e) \  
  43. (e ? (void)0 : (printf("%s:%u error: %s\n%s\n", __FILE__, __LINE__, #e, strerror(errno)), exit(1)))  
  44.   
  45. // number of non-zero blocks in example ex  
  46. #define NUM_NONZERO(ex) (((int *)ex)[labelsize+1])  
  47.   
  48. // float pointer to data segment of example ex  
  49. #define EX_DATA(ex) ((float *)(ex + sizeof(int)*(labelsize+3)))  
  50.   
  51. // class label (+1 or -1) for the example  
  52. #define LABEL(ex) (((int *)ex)[1])  
  53.   
  54. // block label (converted to 0-based index)  
  55. #define BLOCK_IDX(data) (((int)data[0])-1)  
  56.   
  57. int labelsize;  
  58. int dim;  
  59.   
  60. // comparison function for sorting examples   
  61. // 参见 http://blog.sina.com.cn/s/blog_5155e8d401009145.html  
  62. int comp(const void *a, const void *b) {  
  63.   // sort by extended label first, and whole example second...  
  64.     
  65.   //逐字节比较的,当buf1<buf2时,返回值<0,当buf1=buf2时,返回值=0,当buf1>buf2时,返回值>0  
  66.   // 先比较这五个量 [label id level x y],也就是说按照 样本类别->id->level->x->y排序样本  
  67.   int c = memcmp(*((char **)a) + sizeof(int),   
  68.          *((char **)b) + sizeof(int),   
  69.          labelsize*sizeof(int));// 5  
  70.   if (c) //label 不相等  
  71.     return c;  
  72.     
  73.   // labels are the same ,怎么可能会一样呢 id在正负样本集内从1开始是递增的啊  phase 2 阶段同一张图片产生的样本,id都是一样的  
  74.   int alen = **((int **)a);  
  75.   int blen = **((int **)b);  
  76.   if (alen == blen) //长度一样  
  77.     return memcmp(*((char **)a) + sizeof(int),   
  78.           *((char **)b) + sizeof(int),   
  79.           alen); //真霸气,所有字节都比较……  
  80.   return ((alen < blen) ? -1 : 1);//按长度排序  
  81. }  
  82.   
  83. // a collapsed example is a sequence of examples  
  84. struct collapsed {  
  85.   char **seq;  
  86.   int num;  
  87. };  
  88.   
  89. // set of collapsed examples  
  90. struct data {  
  91.   collapsed *x;  
  92.   int num;  
  93.   int numblocks;  
  94.   int *blocksizes;  
  95.   float *regmult;  
  96.   float *learnmult;  
  97. };  
  98.   
  99. // seed the random number generator with the current time  
  100. void seed_time() {  
  101.  struct timeval tp;  
  102.  check(gettimeofday(&tp, NULL) == 0);  
  103.  srand48((long)tp.tv_usec);  
  104. }  
  105.   
  106. static inline double min(double x, double y) { return (x <= y ? x : y); }  
  107. static inline double max(double x, double y) { return (x <= y ? y : x); }  
  108.   
  109. // gradient descent  
  110. //---------------参照论文公式17 后的步骤---------------------------------------  
  111. void gd(double C, double J, data X, double **w, double **lb) {  
  112. //  C=0.0002, J=1, X, w==0, lb==-100);  
  113. //      
  114.   int num = X.num; //组数  
  115.     
  116.   // state for random permutations  
  117.   int *perm = (int *)malloc(sizeof(int)*X.num);  
  118.   check(perm != NULL);  
  119.   
  120.   // state for small cache  
  121.   int *W = (int *)malloc(sizeof(int)*num);  
  122.   check(W != NULL);  
  123.   for (int j = 0; j < num; j++)  
  124.     W[j] = 0;  
  125.   
  126.   int t = 0;  
  127.   while (t < ITER) {  // 5000000 ,霸气……  
  128.     // pick random permutation  
  129.     for (int i = 0; i < num; i++) //组数  
  130.       perm[i] = i;  
  131.     //-------打乱顺序-----  
  132.     // 论文中是随机选择一个样本,这里是随机排好序,再顺序取。  
  133.     // 类似于随机取,但是这里能保证取到全部样本,避免单个样本重复被抽到,重复作用  
  134.     for (int swapi = 0; swapi < num; swapi++) {  
  135.       int swapj = (int)(drand48()*(num-swapi)) + swapi; //drand48 产生 0-1之间的均匀分布  
  136.       int tmp = perm[swapi];  
  137.       perm[swapi] = perm[swapj];  
  138.       perm[swapj] = tmp;  
  139.     }  
  140.   
  141.     // count number of examples in the small cache  
  142.     int cnum = 0; //下面的循环部分的实际循环次数  
  143.     for (int i = 0; i < num; i++) {  
  144.       if (W[i] <= INCACHE) // 3  
  145.         cnum++;  
  146.     }  
  147.     //-------------------------------------------------------  
  148.     for (int swapi = 0; swapi < num; swapi++) {  
  149.       // select example  
  150.       int i = perm[swapi];  
  151.       collapsed x = X.x[i];  
  152.   
  153.       // skip if example is not in small cache  
  154.       //负样本分对一次+1,分错一次清为0  
  155.       //连续三次都分对了,那么这个样本很有可能是 easy 样本  
  156.       //直接让他罚停四次迭代  
  157.       if (W[i] > INCACHE) { //3  
  158.             W[i]--;  
  159.             continue;  
  160.       }  
  161.   
  162.       // learning rate  
  163.       double T = t + 1000.0; //学习率,直接1/t太大了  
  164.       double rateX = cnum * C / T;  
  165.       double rateR = 1.0 / T;  
  166.   
  167.       if (t % 10000 == 0) {  
  168.         printf(".");  
  169.         fflush(stdout); //清除文件缓冲区,文件以写方式打开时将缓冲区内容写入文件  
  170.       }  
  171.       t++;  
  172.         
  173.       // compute max over latent placements  
  174.       //  -----step 3----  
  175.       int M = -1;  
  176.       double V = 0;  
  177.       // 组内循环,选择 Zi=argmax β*f 即文中的第3部  
  178.       // 训练rootfiter时,x.num=1,因为随机产生的负样本其id不同  
  179.       for (int m = 0; m < x.num; m++) {   
  180.         double val = 0;  
  181.         char *ptr = x.seq[m];  
  182.         float *data = EX_DATA(ptr); //特征数据的地址 第9个数据开始,  
  183.         //后面跟着是 block1 label | block2 data|block2 lable | block2 data    
  184.         //                 1      |       1    |     2       |  h*w/2*31个float  
  185.         int blocks = NUM_NONZERO(ptr); // phase 1,phase 2 : 2 个,offset,rootfilter  
  186.         for (int j = 0; j < blocks; j++) {  
  187.           int b = BLOCK_IDX(data); //   
  188.           data++;  
  189.           for (int k = 0; k < X.blocksizes[b]; k++)//(1)=1,(2)= root.h*root.w/2*31  
  190.             val += w[b][k] * data[k]; //第一次循环是0  
  191.           data += X.blocksizes[b];  
  192.         }  
  193.         if (M < 0 || val > V) {  
  194.           M = m;  
  195.           V = val;  
  196.         }  
  197.       }  
  198.         
  199.       // update model  
  200.       //-----step.4 也算了step.5 的一半 ---------------  
  201.       // 梯度下降,减小 w  
  202.       for (int j = 0; j < X.numblocks; j++) {// 2  
  203.         double mult = rateR * X.regmult[j] * X.learnmult[j]; // 0,1  20,1,1/T,对于block2,学习率at就是 1/t,block 1 为0  
  204.         for (int k = 0; k < X.blocksizes[j]; k++) {  
  205.           w[j][k] -= mult * w[j][k]; //不管是分对了,还是分错了,都要减掉 at*β,见公式17下的4,5   
  206.         }  
  207.       }  
  208.       char *ptr = x.seq[M];  
  209.       int label = LABEL(ptr);  
  210.       //----step.5----------分错了,往梯度的负方向移动  
  211.       if (label * V < 1.0)   
  212.       {  
  213.         W[i] = 0;  
  214.         float *data = EX_DATA(ptr);  
  215.         int blocks = NUM_NONZERO(ptr);  
  216.         for (int j = 0; j < blocks; j++) {  
  217.             int b = BLOCK_IDX(data);  
  218.             //  yi*cnum * C / T*1,见论文中 公式16,17  
  219.             double mult = (label > 0 ? J : -1) * rateX * X.learnmult[b];         
  220.             data++;  
  221.             for (int k = 0; k < X.blocksizes[b]; k++)  
  222.                 w[b][k] += mult * data[k];  
  223.             data += X.blocksizes[b];  
  224.         }  
  225.       } else if (label == -1)   
  226.       {  
  227.             if (W[i] == INCACHE) //3  
  228.                 W[i] = WAIT; //10  
  229.             else  
  230.                 W[i]++;  
  231.       }  
  232.     }  
  233.   
  234.     // apply lowerbounds  
  235.     for (int j = 0; j < X.numblocks; j++) {  
  236.       for (int k = 0; k < X.blocksizes[j]; k++) {  
  237.         w[j][k] = max(w[j][k], lb[j][k]);  
  238.       }  
  239.     }  
  240.   
  241.   }  
  242.   
  243.   free(perm);  
  244.   free(W);  
  245. }  
  246.   
  247. // score examples  
  248. double *score(data X, char **examples, int num, double **w) {  
  249.   double *s = (double *)malloc(sizeof(double)*num);  
  250.   check(s != NULL);  
  251.   for (int i = 0; i < num; i++) {  
  252.     s[i] = 0.0;  
  253.     float *data = EX_DATA(examples[i]);  
  254.     int blocks = NUM_NONZERO(examples[i]);  
  255.     for (int j = 0; j < blocks; j++) {  
  256.       int b = BLOCK_IDX(data);  
  257.       data++;  
  258.       for (int k = 0; k < X.blocksizes[b]; k++)  
  259.         s[i] += w[b][k] * data[k];  
  260.       data += X.blocksizes[b];  
  261.     }  
  262.   }  
  263.   return s;    
  264. }  
  265.   
  266. // merge examples with identical labels  
  267. void collapse(data *X, char **examples, int num) {  
  268. //&X, sorted, num_unique  
  269.   collapsed *x = (collapsed *)malloc(sizeof(collapsed)*num);  
  270.   check(x != NULL);  
  271.   int i = 0;  
  272.   x[0].seq = examples;  
  273.   x[0].num = 1;  
  274.   for (int j = 1; j < num; j++) {  
  275.     if (!memcmp(x[i].seq[0]+sizeof(int), examples[j]+sizeof(int),   
  276.         labelsize*sizeof(int))) {  
  277.       x[i].num++; //如果label 五个量相同  
  278.     } else {  
  279.       i++;  
  280.       x[i].seq = &(examples[j]);  
  281.       x[i].num = 1;  
  282.     }  
  283.   }  
  284.   X->x = x;  
  285.   X->num = i+1;    
  286. }  
  287.   
  288. //调用参数 C=0.0002, J=1, hdrfile, datfile, modfile, inffile, lobfile  
  289. int main(int argc, char **argv) {    
  290.   seed_time();  
  291.   int count;  
  292.   data X;  
  293.   
  294.   // command line arguments  
  295.   check(argc == 8);  
  296.   double C = atof(argv[1]);  
  297.   double J = atof(argv[2]);  
  298.   char *hdrfile = argv[3];  
  299.   char *datfile = argv[4];  
  300.   char *modfile = argv[5];  
  301.   char *inffile = argv[6];  
  302.   char *lobfile = argv[7];  
  303.   
  304.   // read header file  
  305.   FILE *f = fopen(hdrfile, "rb");  
  306.   check(f != NULL);  
  307.   int header[3];  
  308.   count = fread(header, sizeof(int), 3, f);  
  309.   check(count == 3);  
  310.   int num = header[0]; //正负样本总数  
  311.   labelsize = header[1]; // labelsize = 5;  [label id level x y]  
  312.   X.numblocks = header[2]; // 2  
  313.   X.blocksizes = (int *)malloc(X.numblocks*sizeof(int)); //(1)=1,(2)= root.h*root.w/2*31  
  314.   count = fread(X.blocksizes, sizeof(int), X.numblocks, f);  
  315.   check(count == X.numblocks);  
  316.   X.regmult = (float *)malloc(sizeof(float)*X.numblocks); //0 ,1  
  317.   check(X.regmult != NULL);  
  318.   count = fread(X.regmult, sizeof(float), X.numblocks, f);  
  319.   check(count == X.numblocks);  
  320.   X.learnmult = (float *)malloc(sizeof(float)*X.numblocks);//20, 1  
  321.   check(X.learnmult != NULL);  
  322.   count = fread(X.learnmult, sizeof(float), X.numblocks, f);  
  323.   check(count == X.numblocks);  
  324.   check(num != 0);  
  325.   fclose(f);  
  326.   printf("%d examples with label size %d and %d blocks\n",  
  327.      num, labelsize, X.numblocks);  
  328.   printf("block size, regularization multiplier, learning rate multiplier\n");  
  329.   dim = 0;  
  330.   for (int i = 0; i < X.numblocks; i++) {  
  331.     dim += X.blocksizes[i];  
  332.     printf("%d, %.2f, %.2f\n", X.blocksizes[i], X.regmult[i], X.learnmult[i]);  
  333.   }  
  334.   
  335.   // ---------------从 datfile 读取  正负 examples----------------  
  336.   // examples [i] 存储了第i个样本的信息 长度为 1 int + 7 int +dim 个float + 1 byte  
  337.   // 1 int      legth 样本包括信息头在内的总字节长度  
  338.   // 7 int      [1/-1 id 0 0 0 2 dim] ,id为样本编号,[label id level centry_x centry_y],2是block个数  
  339.   // dim float  feature,dim=2+1+root.h*root.w/2*31,意义如下  
  340.   //         block1 label | block2 data|block2 lable | block2 data  
  341.   //               1      |       1    |     2       |  h*w/2*31个float  
  342.   // 1 byte     unique=0  
  343.   f = fopen(datfile, "rb");  
  344.   check(f != NULL);  
  345.   printf("Reading examples\n");  
  346.     
  347.   //+,-example数据  
  348.   char **examples = (char **)malloc(num*sizeof(char *));   
  349.     
  350.   check(examples != NULL);  
  351.     for (int i = 0; i < num; i++) {  
  352.     // we use an extra byte in the end of each example to mark unique  
  353.     // we use an extra int at the start of each example to store the   
  354.     // example's byte length (excluding unique flag and this int)  
  355.     //[legth label id level x y  unique] unique=0  
  356.     int buf[labelsize+2];   
  357.     //写入时的值为[1/-1 i 0 0 0 2 dim]   
  358.     count = fread(buf, sizeof(int), labelsize+2, f);  
  359.     check(count == labelsize+2);  
  360.     // byte length of an example's data segment  
  361.       
  362.     //---前面七个是头,后面dim个float是样本特征数据,dim=2+1+root.h*root.w/2*31  
  363.     int len = sizeof(int)*(labelsize+2) + sizeof(float)*buf[labelsize+1];     
  364.     // memory for data, an initial integer, and a final byte  
  365.     examples[i] = (char *)malloc(sizeof(int)+len+1);  
  366.       
  367.     check(examples[i] != NULL);  
  368.     // set data segment's byte length  
  369.     ((int *)examples[i])[0] = len;  
  370.     // set the unique flag to zero  
  371.     examples[i][sizeof(int)+len] = 0;  
  372.     // copy label data into example  
  373.     for (int j = 0; j < labelsize+2; j++)  
  374.       ((int *)examples[i])[j+1] = buf[j];  
  375.     // read the rest of the data segment into the example  
  376.     count = fread(examples[i]+sizeof(int)*(labelsize+3), 1,   
  377.           len-sizeof(int)*(labelsize+2), f);  
  378.     check(count == len-sizeof(int)*(labelsize+2));  
  379.   }  
  380.   fclose(f);  
  381.   printf("done\n");  
  382.   
  383.   // sort  
  384.   printf("Sorting examples\n");  
  385.   char **sorted = (char **)malloc(num*sizeof(char *));  
  386.   check(sorted != NULL);  
  387.   memcpy(sorted, examples, num*sizeof(char *));  
  388.     
  389.   //qsort 库函数,真正的比较函数为 comp  
  390.   //从小到大,快速排序  
  391.   //依次按照 样本类别->id->level->cx->cy  排序样本  
  392.   //如果前面五个量都一样……  
  393.   //1.等长度,比较所有字节;  
  394.   //2.谁长谁小,长度不同是因为不同的component的 尺寸不一致   
  395.     
  396.   qsort(sorted, num, sizeof(char *), comp);   
  397.   printf("done\n");  
  398.   
  399.   // find unique examples  
  400.   // 唯一的样本,unique flag=1,  
  401.   // 相同的样本第一个样本的unique flag为1,其余为0 ,有的样本的位置被,unique替代了,但是并没有完全删除掉  
  402.   int i = 0;  
  403.   int len = *((int *)sorted[0]); //负样本的第一个  
  404.   sorted[0][sizeof(int)+len] = 1; // unique flag 置 1  
  405.   for (int j = 1; j < num; j++) {  
  406.     int alen = *((int *)sorted[i]);  
  407.     int blen = *((int *)sorted[j]);  
  408.     if (alen != blen || memcmp(sorted[i] + sizeof(int), sorted[j] + sizeof(int), alen)) //component不同 || 不同样本  
  409.     {  
  410.       i++;  
  411.       sorted[i] = sorted[j];  
  412.       sorted[i][sizeof(int)+blen] = 1; //标记为 unique  
  413.     }  
  414.   }  
  415.   int num_unique = i+1;  
  416.   printf("%d unique examples\n", num_unique);  
  417.   
  418.   // -------------------collapse examples----------------  
  419.   // 前面是找完全不一样的样本,这里是分组  
  420.   // label 的五个量 [label id level centry_x centry_y] 相同的分为一组,在detect时,写入了datfile   
  421.   // 负样本的 cx,cy都是相对于整张图片的,正样本是相对于剪切后的图像  
  422.   // 前面五个全相同,  
  423.   // 对于phase1 不可能,因为正负样本的id都不相同  
  424.   // 对于phase2 正样本只保留了最有可能是正样本的样本,只有一种情况,  
  425.   // rootfilter1,rootfilter2在同一张图片(id相同),检测出来的 Hard负样本 的cx,cy相同,因此一组最多应该只能出现2个 (待验证)  
  426.   // 原因是此时的latent variable 为(cx,cy,component),上述情况相下,我们只能保留component1或者component2  
  427.   // 后续训练时,这两个量是连续使用的,为什么呢??  
  428.   // collapse.seq(char **) 记录了每一组的第一个样本  
  429.   // collapse.num 每组的个数  
  430.   // X.num 组数  
  431.   // X.x=&collapse[0],也就是第一个 collapse的地址  
  432.   collapse(&X, sorted, num_unique);  
  433.   printf("%d collapsed examples\n", X.num);  
  434.   
  435.   // initial model  
  436.   // 读modfile文件,得到w的初始值。phase 1 初始化为全 0,phase 2 为上一次训练的结果……  
  437.   double **w = (double **)malloc(sizeof(double *)*X.numblocks);//2  
  438.   check(w != NULL);  
  439.   f = fopen(modfile, "rb");  
  440.   for (int i = 0; i < X.numblocks; i++) {  
  441.     w[i] = (double *)malloc(sizeof(double)*X.blocksizes[i]); //(1)=1,(2)= root.h*root.w/2*31  
  442.     check(w[i] != NULL);  
  443.     count = fread(w[i], sizeof(double), X.blocksizes[i], f);  
  444.     check(count == X.blocksizes[i]);  
  445.   }  
  446.   fclose(f);  
  447.   
  448.   // lower bounds  
  449.   // 读lobfile文件,初始化为全 滤波器参数下线-100 ……  
  450.   double **lb = (double **)malloc(sizeof(double *)*X.numblocks);  
  451.   check(lb != NULL);  
  452.   f = fopen(lobfile, "rb");  
  453.   for (int i = 0; i < X.numblocks; i++) {  
  454.     lb[i] = (double *)malloc(sizeof(double)*X.blocksizes[i]);  
  455.     check(lb[i] != NULL);  
  456.     count = fread(lb[i], sizeof(double), X.blocksizes[i], f);  
  457.     check(count == X.blocksizes[i]);  
  458.   }  
  459.   fclose(f);  
  460.     
  461.   
  462.   printf("Training");  
  463.   //-------------------------------- train -------------------------------  
  464.   //-----梯度下降发训练参数 w,参见论文 公式17 后面的步骤  
  465.   gd(C, J, X, w, lb);  
  466.   printf("done\n");  
  467.   
  468.   // save model  
  469.   printf("Saving model\n");  
  470.   f = fopen(modfile, "wb");  
  471.   check(f != NULL);  
  472.   //   存储 block1,block2的训练结果,w  
  473.   for (int i = 0; i < X.numblocks; i++) {  
  474.     count = fwrite(w[i], sizeof(double), X.blocksizes[i], f);  
  475.     check(count == X.blocksizes[i]);  
  476.   }  
  477.   fclose(f);  
  478.   
  479.   // score examples  
  480.   // ---所有的样本都的得分,没有乘以 label y   
  481.   printf("Scoring\n");  
  482.   double *s = score(X, examples, num, w);  
  483.   
  484.   // ---------Write info file-------------  
  485.   printf("Writing info file\n");  
  486.   f = fopen(inffile, "w");  
  487.   check(f != NULL);  
  488.   for (int i = 0; i < num; i++) {  
  489.     int len = ((int *)examples[i])[0];  
  490.     // label, score, unique flag  
  491.     count = fprintf(f, "%d\t%f\t%d\n", ((int *)examples[i])[1], s[i],   
  492.                     (int)examples[i][sizeof(int)+len]);  
  493.     check(count > 0);  
  494.   }  
  495.   fclose(f);  
  496.     
  497.   printf("Freeing memory\n");  
  498.   for (int i = 0; i < X.numblocks; i++) {  
  499.     free(w[i]);  
  500.     free(lb[i]);  
  501.   }  
  502.   free(w);  
  503.   free(lb);  
  504.   free(s);  
  505.   for (int i = 0; i < num; i++)  
  506.     free(examples[i]);  
  507.   free(examples);  
  508.   free(sorted);  
  509.   free(X.x);  
  510.   free(X.blocksizes);  
  511.   free(X.regmult);  
  512.   free(X.learnmult);  
  513.   
  514.   return 0;  
  515. }  

你可能感兴趣的:(DPM(Defomable Parts Model) 源码分析-训练(三))