windows下DPM经验2

原文:http://blog.sciencenet.cn/blog-261330-663183.html

接着昨天的继续

昨天吧demo()跑通了,今天我们继续修改训练部分。

同样参看了pozen同学的博客。

1、首先下载voc的数据库和相应的VOCdevkit。(注意吧数据也放在VOCdevkit的目录中)

2、修改global.m文件中的文件路径(根据自己需求和自己缩放位置修改)

3、根据pozen同学的说明修改了一些文件,还有unix换成了system的命令,一些命令换成windows下的命令(参考前一篇文章)。

说下我自己遇到的问题:

4、procid.m文件中的“/”修改为“\"。因为window下目录与linux下的差异。

5、还有就是learn.cpp的编译了。

遇到的问题有:

1、srand48 和drand48 在windows下没有,根据原理自己编了一个。不对请指出:

[html]  view plain copy
  1. #define MNWZ 0x100000000    
  2. #define ANWZ 0x5DEECE66D    
  3. #define CNWZ 0xB16   
  4. #define INFINITY 0xFFFFFFFFF  
  5.   
  6. int labelsize;  
  7. int dim;  
  8.   
  9. static unsigned long long seed = 1;  
  10.   
  11. double drand48(void)    
  12. {    
  13.     seed = (ANWZ * seed + CNWZ) & 0xFFFFFFFFFFFFLL;    
  14.     unsigned int x = seed >> 16;    
  15.     return  ((double)x / (double)MNWZ);       
  16. }  
  17.    
  18. //static unsigned long long seed = 1;  
  19.   
  20. void srand48(unsigned int i)    
  21. {    
  22.     seed  = (((long long int)i) << 16) | rand();    
  23. }  

2、INFINITY(linux下无穷大的标记),在windows下没有这个标志,因为是double型的数据

于是我定义了:#define INFINITY 0xFFFFFFFFF(不知道对不对,运行没错误)。

3、

[html]  view plain copy
  1. string filepath = string(logdir) + "/learnlog/" + string(logtag) + ".log";  

一直报错,最后我吧#include <string.h>改成了#include <string>就没问题了。应该说搜索路径的问题。

之后就可以运行了。如果不确定,大家可以先运行matlab,生成需要的一些文件,然后通过输入命令行去单步调试learn.cpp文件。

源代码中:readme中描述:

1. Download and install the 2006/2007/2008 PASCAL VOC devkit and dataset.

   (you should set VOCopts.testset='test' in VOCinit.m)

2. Modify 'globals.m' according to your configuration.

3. Run 'make' to compile learn.cc, the LSVM gradient descent code.

   (Run from a shell, not Matlab.)

4. Start matlab.

5. Run the 'compile' script to compile the helper functions.

   (you may need to edit compile.m to use a different convolution 

    routine depending on your system)

6. Use the 'pascal' script to train and evaluate a model. 

example:

> pascal('person', 3);   % train and evaluate a 6 component person model

这里贴出learn.cpp最后成功运行的代码:

[html]  view plain copy
  1. #include <stdio.h>  
  2. #include <stdlib.h>  
  3. #include <string>  
  4. #include <math.h>  
  5. #include <time.h>  
  6. #include <errno.h>  
  7. #include <fstream>  
  8. #include <iostream>  
  9. //#include "stdafx.h"  
  10.   
  11. using namespace std;  
  12.   
  13. /*  
  14.  * Optimize LSVM objective function via gradient descent.  
  15.  *  
  16.  * We use an adaptive cache mechanism.  After a negative example  
  17.  * scores beyond the margin multiple times it is removed from the  
  18.  * training set for a fixed number of iterations.  
  19.  */  
  20.   
  21. // Data File Format  
  22. // EXAMPLE*  
  23. //   
  24. // EXAMPLE:  
  25. //  long label          ints  
  26. //  blocks              int  
  27. //  dim                 int  
  28. //  DATA{blocks}  
  29. //  
  30. // DATA:  
  31. //  block label         float  
  32. //  block data          floats  
  33. //  
  34. // Internal Binary Format  
  35. //  len           int (byte length of EXAMPLE)  
  36. //  EXAMPLE       <see above>  
  37. //  unique flag   byte  
  38.   
  39. // number of iterations  
  40.   
  41. /*#ifndef DRAND48_H    
  42. #define DRAND48_H    
  43.     
  44. #include <stdlib.h>  */  
  45.     
  46. //#define m 0x100000000LL    
  47. //#define a 0x5DEECE66DLL    
  48. //static unsigned long long seed = 1;  
  49. //#endif  
  50. //#define Infinity 1.0+308  
  51.   
  52.   
  53. #define ITER 10e6  
  54.   
  55. // minimum # of iterations before termination  
  56. #define MIN_ITER 5e6  
  57.   
  58. // convergence threshold  
  59. #define DELTA_STOP 0.9995  
  60.   
  61. // number of times in a row the convergence threshold  
  62. // must be reached before stopping  
  63. #define STOP_COUNT 5  
  64.   
  65. // small cache parameters  
  66. #define INCACHE 25  
  67. #define MINWAIT (INCACHE+25)  
  68. #define REGFREQ 20  
  69.   
  70. // error checking  
  71. #define check(e) \  
  72. (e ? (void)0 : (printf("%s:%u error: %s\n%s\n", __FILE__, __LINE__, #e, strerror(errno)), exit(1)))  
  73.   
  74. // number of non-zero blocks in example ex  
  75. #define NUM_NONZERO(ex) (((int *)ex)[labelsize+1])  
  76.   
  77. // float pointer to data segment of example ex  
  78. #define EX_DATA(ex) ((float *)(ex + sizeof(int)*(labelsize+3)))  
  79.   
  80. // class label (+1 or -1) for the example  
  81. #define LABEL(ex) (((int *)ex)[1])  
  82.   
  83. // block label (converted to 0-based index)  
  84. #define BLOCK_IDX(data) (((int)data[0])-1)  
  85.   
  86. // set to 0 to use max-component L2 regularization  
  87. // set to 1 to use full model L2 regularization  
  88. #define FULL_L2 0  
  89.   
  90. #define MNWZ 0x100000000    
  91. #define ANWZ 0x5DEECE66D    
  92. #define CNWZ 0xB16   
  93. #define INFINITY 0xFFFFFFFFF  
  94.   
  95. int labelsize;  
  96. int dim;  
  97.   
  98. static unsigned long long seed = 1;  
  99.   
  100. double drand48(void)    
  101. {    
  102.     seed = (ANWZ * seed + CNWZ) & 0xFFFFFFFFFFFFLL;    
  103.     unsigned int x = seed >> 16;    
  104.     return  ((double)x / (double)MNWZ);       
  105. }  
  106.    
  107. //static unsigned long long seed = 1;  
  108.   
  109. void srand48(unsigned int i)    
  110. {    
  111.     seed  = (((long long int)i) << 16) | rand();    
  112. }  
  113.    
  114. // comparison function for sorting examples   
  115. int comp(const void *a, const void *b) {  
  116.   // sort by extended label first, and whole example second...  
  117.   int c = memcmp(*((char **)a) + sizeof(int),   
  118.          *((char **)b) + sizeof(int),   
  119.          labelsize*sizeof(int));  
  120.   if (c)  
  121.     return c;  
  122.     
  123.   // labels are the same    
  124.   int alen = **((int **)a);  
  125.   int blen = **((int **)b);  
  126.   if (alen == blen)  
  127.     return memcmp(*((char **)a) + sizeof(int),   
  128.           *((char **)b) + sizeof(int),   
  129.           alen);  
  130.   return ((alen < blen) ? -1 : 1);  
  131. }  
  132.   
  133. // a collapsed example is a sequence of examples  
  134. struct collapsed {  
  135.   char **seq;  
  136.   int num;  
  137. };  
  138.   
  139. // the two node types in an AND/OR tree  
  140. enum node_type { OR, AND };  
  141.   
  142. // set of collapsed examples  
  143. struct data {  
  144.   collapsed *x;  
  145.   int num;  
  146.   int numblocks;  
  147.   int numcomponents;  
  148.   int *blocksizes;  
  149.   int *componentsizes;  
  150.   int **componentblocks;  
  151.   float *regmult;  
  152.   float *learnmult;  
  153. };  
  154.   
  155. // seed the random number generator with an arbitrary (fixed) value  
  156. void seed_rand() {  
  157.     srand48(3);  
  158.     //srand(3);  
  159. }  
  160.   
  161. static inline double min(double x, double y) { return (x <= y ? x : y); }  
  162. static inline double max(double x, double y) { return (x <= y ? y : x); }  
  163.   
  164. // compute the score of an example  
  165. static inline double ex_score(const char *ex, data X, double **w) {  
  166.   double val = 0.0;  
  167.   float *data = EX_DATA(ex);  
  168.   int blocks = NUM_NONZERO(ex);  
  169.   for (int j = 0; j < blocks; j++) {  
  170.     int b = BLOCK_IDX(data);  
  171.     data++;  
  172.     double blockval = 0;  
  173.     for (int k = 0; k < X.blocksizes[b]; k++)  
  174.       blockval += w[b][k] * data[k];  
  175.     data += X.blocksizes[b];  
  176.     val += blockval;  
  177.   }  
  178.   return val;  
  179. }  
  180.   
  181. // return the value of the object function.  
  182. // out[0] : loss on negative examples  
  183. // out[1] : loss on positive examples  
  184. // out[2] : regularization term's value  
  185. double compute_loss(double out[3], double C, double J, data X, double **w) {  
  186.   double loss = 0.0;  
  187. #if FULL_L2  
  188.   // compute ||w||^2  
  189.   for (int j = 0; j < X.numblocks; j++) {  
  190.     for (int k = 0; k < X.blocksizes[j]; k++) {  
  191.       loss += w[j][k] * w[j][k] * X.regmult[j];  
  192.     }  
  193.   }  
  194. #else  
  195.   // compute max norm^2 component  
  196.   for (int c = 0; c < X.numcomponents; c++) {  
  197.     double val = 0;  
  198.     for (int i = 0; i < X.componentsizes[c]; i++) {  
  199.       int b = X.componentblocks[c][i];  
  200.       double blockval = 0;  
  201.       for (int k = 0; k < X.blocksizes[b]; k++)  
  202.         blockval += w[b][k] * w[b][k] * X.regmult[b];  
  203.       val += blockval;  
  204.     }  
  205.     if (val > loss)  
  206.       loss = val;  
  207.   }  
  208. #endif  
  209.   loss *= 0.5;  
  210.   
  211.   // record the regularization term  
  212.   out[2] = loss;  
  213.   
  214.   // compute loss from the training data  
  215.   for (int l = 0; l <= 1; l++) {  
  216.     // which label subset to look at: -1 or 1  
  217.     int subset = (l*2)-1;  
  218.     double subsetloss = 0.0;  
  219.     for (int i = 0; i < X.num; i++) {  
  220.       collapsed x = X.x[i];  
  221.   
  222.       // only consider examples in the target subset  
  223.       char *ptr = x.seq[0];  
  224.       if (LABEL(ptr) != subset)  
  225.         continue;  
  226.   
  227.       // compute max over latent placements  
  228.       int M = -1;  
  229.       double V = -INFINITY;  
  230.       //double V = -NWZ;  
  231.       for (int m = 0; m < x.num; m++) {  
  232.         double val = ex_score(x.seq[m], X, w);  
  233.         if (val > V) {  
  234.           M = m;  
  235.           V = val;  
  236.         }  
  237.       }  
  238.   
  239.       // compute loss on max  
  240.       ptr = x.seq[M];  
  241.       int label = LABEL(ptr);  
  242.       double mult = C * (label == 1 ? J : 1);  
  243.       subsetloss += mult * max(0.0, 1.0-label*V);  
  244.     }  
  245.     loss += subsetloss;  
  246.     out[l] = subsetloss;  
  247.   }  
  248.   
  249.   return loss;  
  250. }  
  251.   
  252. // gradient descent  
  253. void gd(double C, double J, data X, double **w, double **lb, char *logdir, char *logtag) {  
  254.   ofstream logfile;  
  255.   string filepath = string(logdir) + "/learnlog/" + string(logtag) + ".log";  
  256.    
  257.   /*char* filepath;  
  258.   strcat(filepath,logdir);  
  259.   strcat(filepath,"/learnlog/");  
  260.   strcat(filepath,logtag);  
  261.   strcat(filepath,"/log");*/  
  262.     
  263.   logfile.open(filepath.c_str());  
  264.   //logfile.open(filepath);  
  265.   logfile.precision(14);  
  266.   logfile.setf(ios::fixed, ios::floatfield);  
  267.   
  268.   int num = X.num;  
  269.     
  270.   // state for random permutations  
  271.   int *perm = (int *)malloc(sizeof(int)*X.num);  
  272.   check(perm != NULL);  
  273.   
  274.   // state for small cache  
  275.   int *W = (int *)malloc(sizeof(int)*num);  
  276.   check(W != NULL);  
  277.   for (int j = 0; j < num; j++)  
  278.     W[j] = INCACHE;  
  279.   
  280.   double prev_loss = 1E9;  
  281.   
  282.   bool converged = false;  
  283.   int stop_count = 0;  
  284.   int t = 0;  
  285.   while (t < ITER && !converged) {  
  286.     // pick random permutation  
  287.     for (int i = 0; i < num; i++)  
  288.       perm[i] = i;  
  289.     for (int swapi = 0; swapi < num; swapi++) {  
  290.       int swapj = (int)(drand48()*(num-swapi)) + swapi;  
  291.       //int swapj = (int)(rand()*(num-swapi)) + swapi;  
  292.       int tmp = perm[swapi];  
  293.       perm[swapi] = perm[swapj];  
  294.       perm[swapj] = tmp;  
  295.     }  
  296.   
  297.     // count number of examples in the small cache  
  298.     int cnum = 0;  
  299.     for (int i = 0; i < num; i++)  
  300.       if (W[i] <= INCACHE)  
  301.     cnum++;  
  302.   
  303.     int numupdated = 0;  
  304.     for (int swapi = 0; swapi < num; swapi++) {  
  305.       // select example  
  306.       int i = perm[swapi];  
  307.   
  308.       // skip if example is not in small cache  
  309.       if (W[i] > INCACHE) {  
  310.     W[i]--;  
  311.     continue;  
  312.       }  
  313.   
  314.       collapsed x = X.x[i];  
  315.   
  316.       // learning rate  
  317.       double T = min(ITER/2.0, t + 10000.0);  
  318.       double rateX = cnum * C / T;  
  319.   
  320.       t++;  
  321.       if (t % 100000 == 0) {  
  322.         double info[3];  
  323.         double loss = compute_loss(info, C, J, X, w);  
  324.         double delta = 1.0 - (fabs(prev_loss - loss) / loss);  
  325.         logfile << t << "\t" << loss << "\t" << delta << endl;  
  326.         if (delta >= DELTA_STOP && t >= MIN_ITER) {  
  327.           stop_count++;  
  328.           if (stop_count > STOP_COUNT)  
  329.             converged = true;  
  330.         } else if (stop_count > 0) {  
  331.           stop_count = 0;  
  332.         }  
  333.         prev_loss = loss;  
  334.         printf("\r%7.2f%% of max # iterations "  
  335.                "(delta = %.5f; stop count = %d)",   
  336.                100*double(t)/double(ITER), max(delta, 0.0),   
  337.                STOP_COUNT - stop_count + 1);  
  338.     fflush(stdout);  
  339.         if (converged)  
  340.           break;  
  341.       }  
  342.         
  343.       // compute max over latent placements  
  344.       int M = -1;  
  345.       double V = -INFINITY;  
  346.       //double V = -NWZ;  
  347.       for (int m = 0; m < x.num; m++) {  
  348.     double val = ex_score(x.seq[m], X, w);  
  349.     if (val > V) {  
  350.       M = m;  
  351.       V = val;  
  352.     }  
  353.     }  
  354.         
  355.       char *ptr = x.seq[M];  
  356.       int label = LABEL(ptr);  
  357.       if (label * V < 1.0) {  
  358.         numupdated++;  
  359.     W[i] = 0;  
  360.     float *data = EX_DATA(ptr);  
  361.     int blocks = NUM_NONZERO(ptr);  
  362.     for (int j = 0; j < blocks; j++) {  
  363.       int b = BLOCK_IDX(data);  
  364.       double mult = (label > 0 ? J : -1) * rateX * X.learnmult[b];        
  365.       data++;  
  366.       for (int k = 0; k < X.blocksizes[b]; k++)  
  367.         w[b][k] += mult * data[k];  
  368.       data += X.blocksizes[b];  
  369.     }  
  370.       } else {  
  371.     if (W[i] == INCACHE)  
  372.           W[i] = MINWAIT + (int)(drand48()*50);  
  373.           //W[i] = MINWAIT + (int)(rand()*50);  
  374.     else  
  375.       W[i]++;  
  376.       }  
  377.   
  378.       // periodically regularize the model  
  379.       if (t % REGFREQ == 0) {  
  380.         // apply lowerbounds  
  381.         for (int j = 0; j < X.numblocks; j++)  
  382.           for (int k = 0; k < X.blocksizes[j]; k++)  
  383.             w[j][k] = max(w[j][k], lb[j][k]);  
  384.   
  385.         double rateR = 1.0 / T;  
  386.   
  387. #if FULL_L2   
  388.         // update model  
  389.         for (int j = 0; j < X.numblocks; j++) {  
  390.           double mult = rateR * X.regmult[j] * X.learnmult[j];  
  391.           mult = pow((1-mult), REGFREQ);  
  392.           for (int k = 0; k < X.blocksizes[j]; k++) {  
  393.             w[j][k] = mult * w[j][k];  
  394.           }  
  395.         }  
  396. #else  
  397.         // assume simple mixture model  
  398.         int maxc = 0;  
  399.         double bestval = 0;  
  400.         for (int c = 0; c < X.numcomponents; c++) {  
  401.           double val = 0;  
  402.           for (int i = 0; i < X.componentsizes[c]; i++) {  
  403.             int b = X.componentblocks[c][i];  
  404.             double blockval = 0;  
  405.             for (int k = 0; k < X.blocksizes[b]; k++)  
  406.               blockval += w[b][k] * w[b][k] * X.regmult[b];  
  407.             val += blockval;  
  408.           }  
  409.           if (val > bestval) {  
  410.             maxc = c;  
  411.             bestval = val;  
  412.           }  
  413.         }  
  414.         for (int i = 0; i < X.componentsizes[maxc]; i++) {  
  415.           int b = X.componentblocks[maxc][i];  
  416.           double mult = rateR * X.regmult[b] * X.learnmult[b];          
  417.           mult = pow((1-mult), REGFREQ);  
  418.           for (int k = 0; k < X.blocksizes[b]; k++)  
  419.             w[b][k] = mult * w[b][k];  
  420.         }  
  421. #endif  
  422.       }  
  423.     }  
  424.   }  
  425.   
  426.   if (converged)  
  427.     printf("\nTermination criteria reached after %d iterations.\n", t);  
  428.   else  
  429.     printf("\nMax iteration count reached.\n", t);  
  430.   
  431.   free(perm);  
  432.   free(W);  
  433.   logfile.close();  
  434. }  
  435.   
  436. // score examples  
  437. double *score(data X, char **examples, int num, double **w) {  
  438.   double *s = (double *)malloc(sizeof(double)*num);  
  439.   check(s != NULL);  
  440.   for (int i = 0; i < num; i++)  
  441.     s[i] = ex_score(examples[i], X, w);  
  442.   return s;    
  443. }  
  444.   
  445. // merge examples with identical labels  
  446. void collapse(data *X, char **examples, int num) {  
  447.   collapsed *x = (collapsed *)malloc(sizeof(collapsed)*num);  
  448.   check(x != NULL);  
  449.   int i = 0;  
  450.   x[0].seq = examples;  
  451.   x[0].num = 1;  
  452.   for (int j = 1; j < num; j++) {  
  453.     if (!memcmp(x[i].seq[0]+sizeof(int), examples[j]+sizeof(int),   
  454.         labelsize*sizeof(int))) {  
  455.       x[i].num++;  
  456.     } else {  
  457.       i++;  
  458.       x[i].seq = &(examples[j]);  
  459.       x[i].num = 1;  
  460.     }  
  461.   }  
  462.   X->x = x;  
  463.   X->num = i+1;    
  464. }  
  465.   
  466. int main(int argc, char **argv) {    
  467.   seed_rand();  
  468.   int count;  
  469.   data X;  
  470.   
  471.   // command line arguments  
  472.   check(argc == 12);  
  473.   double C = atof(argv[1]);  
  474.   double J = atof(argv[2]);  
  475.   char *hdrfile = argv[3];  
  476.   char *datfile = argv[4];  
  477.   char *modfile = argv[5];  
  478.   char *inffile = argv[6];  
  479.   char *lobfile = argv[7];  
  480.   char *cmpfile = argv[8];  
  481.   char *objfile = argv[9];  
  482.   char *logdir  = argv[10];  
  483.   char *logtag  = argv[11];  
  484.   
  485.   // read header file  
  486.   FILE *f = fopen(hdrfile, "rb");  
  487.   check(f != NULL);  
  488.   int header[3];  
  489.   count = fread(header, sizeof(int), 3, f);  
  490.   check(count == 3);  
  491.   int num = header[0];  
  492.   labelsize = header[1];  
  493.   X.numblocks = header[2];  
  494.   X.blocksizes = (int *)malloc(X.numblocks*sizeof(int));  
  495.   count = fread(X.blocksizes, sizeof(int), X.numblocks, f);  
  496.   check(count == X.numblocks);  
  497.   X.regmult = (float *)malloc(sizeof(float)*X.numblocks);  
  498.   check(X.regmult != NULL);  
  499.   count = fread(X.regmult, sizeof(float), X.numblocks, f);  
  500.   check(count == X.numblocks);  
  501.   X.learnmult = (float *)malloc(sizeof(float)*X.numblocks);  
  502.   check(X.learnmult != NULL);  
  503.   count = fread(X.learnmult, sizeof(float), X.numblocks, f);  
  504.   check(count == X.numblocks);  
  505.   check(num != 0);  
  506.   fclose(f);  
  507.   printf("%d examples with label size %d and %d blocks\n",  
  508.      num, labelsize, X.numblocks);  
  509.   printf("block size, regularization multiplier, learning rate multiplier\n");  
  510.   dim = 0;  
  511.   for (int i = 0; i < X.numblocks; i++) {  
  512.     dim += X.blocksizes[i];  
  513.     printf("%d, %.2f, %.2f\n", X.blocksizes[i], X.regmult[i], X.learnmult[i]);  
  514.   }  
  515.   
  516.   // read component info file  
  517.   // format: #components {#blocks blk1 ... blk#blocks}^#components  
  518.   f = fopen(cmpfile, "rb");  
  519.   count = fread(&X.numcomponents, sizeof(int), 1, f);  
  520.   check(count == 1);  
  521.   printf("the model has %d components\n", X.numcomponents);  
  522.   X.componentblocks = (int **)malloc(X.numcomponents*sizeof(int *));  
  523.   X.componentsizes = (int *)malloc(X.numcomponents*sizeof(int));  
  524.   for (int i = 0; i < X.numcomponents; i++) {  
  525.     count = fread(&X.componentsizes[i], sizeof(int), 1, f);  
  526.     check(count == 1);  
  527.     printf("component %d has %d blocks:", i, X.componentsizes[i]);  
  528.     X.componentblocks[i] = (int *)malloc(X.componentsizes[i]*sizeof(int));  
  529.     count = fread(X.componentblocks[i], sizeof(int), X.componentsizes[i], f);  
  530.     check(count == X.componentsizes[i]);  
  531.     for (int j = 0; j < X.componentsizes[i]; j++)  
  532.       printf(" %d", X.componentblocks[i][j]);  
  533.     printf("\n");  
  534.   }  
  535.   fclose(f);  
  536.   
  537.   // read examples  
  538.   f = fopen(datfile, "rb");  
  539.   check(f != NULL);  
  540.   printf("Reading examples\n");  
  541.   char **examples = (char **)malloc(num*sizeof(char *));  
  542.   check(examples != NULL);  
  543.   for (int i = 0; i < num; i++) {  
  544.     // we use an extra byte in the end of each example to mark unique  
  545.     // we use an extra int at the start of each example to store the   
  546.     // example's byte length (excluding unique flag and this int)  
  547.     //int buf[labelsize+2];  
  548.     int *buf = new int[labelsize+2];  
  549.         
  550.     count = fread(buf, sizeof(int), labelsize+2, f);  
  551.     check(count == labelsize+2);  
  552.     // byte length of an example's data segment  
  553.     int len = sizeof(int)*(labelsize+2) + sizeof(float)*buf[labelsize+1];  
  554.     // memory for data, an initial integer, and a final byte  
  555.     examples[i] = (char *)malloc(sizeof(int)+len+1);  
  556.     check(examples[i] != NULL);  
  557.     // set data segment's byte length  
  558.     ((int *)examples[i])[0] = len;  
  559.     // set the unique flag to zero  
  560.     examples[i][sizeof(int)+len] = 0;  
  561.     // copy label data into example  
  562.     for (int j = 0; j < labelsize+2; j++)  
  563.       ((int *)examples[i])[j+1] = buf[j];  
  564.     // read the rest of the data segment into the example  
  565.     count = fread(examples[i]+sizeof(int)*(labelsize+3), 1,   
  566.           len-sizeof(int)*(labelsize+2), f);  
  567.     check(count == len-sizeof(int)*(labelsize+2));  
  568.   
  569.     delete [] buf;  
  570.   }  
  571.   fclose(f);  
  572.   printf("done\n");  
  573.   
  574.   // sort  
  575.   printf("Sorting examples\n");  
  576.   char **sorted = (char **)malloc(num*sizeof(char *));  
  577.   check(sorted != NULL);  
  578.   memcpy(sorted, examples, num*sizeof(char *));  
  579.   qsort(sorted, num, sizeof(char *), comp);  
  580.   printf("done\n");  
  581.   
  582.   // find unique examples  
  583.   int i = 0;  
  584.   int len = *((int *)sorted[0]);  
  585.   sorted[0][sizeof(int)+len] = 1;  
  586.   for (int j = 1; j < num; j++) {  
  587.     int alen = *((int *)sorted[i]);  
  588.     int blen = *((int *)sorted[j]);  
  589.     if (alen != blen ||   
  590.     memcmp(sorted[i] + sizeof(int), sorted[j] + sizeof(int), alen)) {  
  591.       i++;  
  592.       sorted[i] = sorted[j];  
  593.       sorted[i][sizeof(int)+blen] = 1;  
  594.     }  
  595.   }  
  596.   int num_unique = i+1;  
  597.   printf("%d unique examples\n", num_unique);  
  598.   
  599.   // collapse examples  
  600.   collapse(&X, sorted, num_unique);  
  601.   printf("%d collapsed examples\n", X.num);  
  602.   
  603.   // initial model  
  604.   double **w = (double **)malloc(sizeof(double *)*X.numblocks);  
  605.   check(w != NULL);  
  606.   f = fopen(modfile, "rb");  
  607.   for (int i = 0; i < X.numblocks; i++) {  
  608.     w[i] = (double *)malloc(sizeof(double)*X.blocksizes[i]);  
  609.     check(w[i] != NULL);  
  610.     count = fread(w[i], sizeof(double), X.blocksizes[i], f);  
  611.     check(count == X.blocksizes[i]);  
  612.   }  
  613.   fclose(f);  
  614.   
  615.   // lower bounds  
  616.   double **lb = (double **)malloc(sizeof(double *)*X.numblocks);  
  617.   check(lb != NULL);  
  618.   f = fopen(lobfile, "rb");  
  619.   for (int i = 0; i < X.numblocks; i++) {  
  620.     lb[i] = (double *)malloc(sizeof(double)*X.blocksizes[i]);  
  621.     check(lb[i] != NULL);  
  622.     count = fread(lb[i], sizeof(double), X.blocksizes[i], f);  
  623.     check(count == X.blocksizes[i]);  
  624.   }  
  625.   fclose(f);  
  626.     
  627.   // train  
  628.   printf("Training\n");  
  629.   gd(C, J, X, w, lb, logdir, logtag);  
  630.   printf("done\n");  
  631.   
  632.   // save model  
  633.   printf("Saving model\n");  
  634.   f = fopen(modfile, "wb");  
  635.   check(f != NULL);  
  636.   for (int i = 0; i < X.numblocks; i++) {  
  637.     count = fwrite(w[i], sizeof(double), X.blocksizes[i], f);  
  638.     check(count == X.blocksizes[i]);  
  639.   }  
  640.   fclose(f);  
  641.   
  642.   // score examples  
  643.   printf("Scoring\n");  
  644.   double *s = score(X, examples, num, w);  
  645.   
  646.   // Write info file  
  647.   printf("Writing info file\n");  
  648.   f = fopen(inffile, "w");  
  649.   check(f != NULL);  
  650.   for (int i = 0; i < num; i++) {  
  651.     int len = ((int *)examples[i])[0];  
  652.     // label, score, unique flag  
  653.     count = fprintf(f, "%d\t%f\t%d\n", ((int *)examples[i])[1], s[i],   
  654.                     (int)examples[i][sizeof(int)+len]);  
  655.     check(count > 0);  
  656.   }  
  657.   fclose(f);  
  658.   
  659.   // compute loss and write it to a file  
  660.   double lossinfo[3];  
  661.   compute_loss(lossinfo, C, J, X, w);  
  662.   printf("Writing objective function info file\n");  
  663.   f = fopen(objfile, "w");  
  664.   count = fprintf(f, "%f\t%f\t%f", lossinfo[0], lossinfo[1], lossinfo[2]);  
  665.   check(count > 0);  
  666.   fclose(f);  
  667.     
  668.   printf("Freeing memory\n");  
  669.   for (int i = 0; i < X.numblocks; i++) {  
  670.     free(w[i]);  
  671.     free(lb[i]);  
  672.   }  
  673.   free(w);  
  674.   free(lb);  
  675.   free(s);  
  676.   for (int i = 0; i < num; i++)  
  677.     free(examples[i]);  
  678.   free(examples);  
  679.   free(sorted);  
  680.   free(X.x);  
  681.   free(X.blocksizes);  
  682.   free(X.regmult);  
  683.   free(X.learnmult);  
  684.   for (int i = 0; i < X.numcomponents; i++)  
  685.     free(X.componentblocks[i]);  
  686.   free(X.componentblocks);  
  687.   free(X.componentsizes);  
  688.   
  689.   return 0;  
  690. }  

你可能感兴趣的:(DPM)