I can't make sense of large parts of this code, so for now I'm just treating it as a tool rather than something I fully understand.
Main program:
main.cpp
// NOTE: the header names between angle brackets were stripped in the original paste;
// the includes below are reconstructed from what the code actually uses.
#include "rbmpredictdata.h"
#include "rbmdata.h"
#include "rbm.h"
#include "rbmparallel.h"
#include <string>
#include <deque>
#include <fstream>
#include <sstream>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cstdlib>

// Forward declarations: the original source presumably declares these before main();
// they are added here so the listing compiles in the order shown.
void parseArgs(int argc, char* argv[]);
void printConfig();
void safeLoad(RbmData& p, const string& fname);
int extractIncrements(int i);

extern char helpString[];
extern string trainFilename, testFilename, savePrefix;
extern float vLearn, hLearn, wLearn, wCost;
extern int nHidden, nEpochs;
extern float finalMomentum, initialMomentum;
extern int finalMomentumStart;
extern deque<int> tIncrements;
extern deque<int> epochsToSaveAt;
extern int seed;
extern bool predictAlways;
extern bool parallel;
extern int nThreads;

int main(int argc, char* argv[]) {
    parseArgs(argc, argv);      // Step 1: parse the arguments (number of hidden nodes, trainFilename, testFilename, ...)
    srand(seed);
    printConfig();              // Step 2: print the configuration

    RbmPredictData predictData;
    safeLoad(predictData, testFilename);    // Step 3: load the test data into matrix form
    cout << "Done loading test data" << flush;

    RbmData data;
    safeLoad(data, trainFilename);          // same as Step 3, but for the training set
    cout << "\rDone loading data. " << endl;

    Rbm* r = NULL;
    if (parallel)
        r = new RbmParallel(nThreads, data, nHidden);
    else
        r = new Rbm(data, nHidden);         // Step 4: construct the RBM and initialize W, hBias and vBias (see rbm.cpp)
    r->momentum = initialMomentum;
    r->hBiasLearnRate = hLearn;
    r->vBiasLearnRate = vLearn;
    r->WlearnRate = wLearn;
    r->weightDecay = wCost;

    for (int i = 1; i <= nEpochs; i++) {
        int increment = extractIncrements(i);   // minor bookkeeping: how much T grows this epoch
        if (increment) {
            r->T += increment;
            cout << "\tT = " << r->T << endl;
        }
        if (i == finalMomentumStart)
            r->momentum = finalMomentum;

        r->performEpoch(data);                  // Step 5: the key step, one training epoch (see rbm.cpp)

        if (!epochsToSaveAt.empty() && epochsToSaveAt.front() == i) {
            epochsToSaveAt.pop_front();
            stringstream t;
            t << savePrefix << i;
            double d = r->predict(data, predictData, t.str());
            cout << i << ": saving to " << t.str() << "(" << d << ")" << endl;
        } else if (predictAlways) {
            double d = r->predict(data, predictData);   // Step 6: predict on the test set (see rbm.cpp)
            cout << i << ": " << d << endl;
        } else {
            cout << i << ": omitting prediction" << endl;
        }
    }

    return 0;
}

// Step 1: parse command-line arguments
void parseArgs(int argc, char* argv[]) {
    int current = 1;
    // TODO: make more user friendly in terms of error handling.
    // and coding style leaves something to be desired...
    while (current < argc) {
        if (strcmp("-h", argv[current]) == 0) {
            printf(helpString, argv[0]);
            exit(0);
        } else if (strcmp("-d", argv[current]) == 0) {
            cout << "Defaults: " << endl;
            printConfig();
            exit(0);
        } else if (strcmp("-v", argv[current]) == 0) {
            vLearn = atof(argv[current + 1]);
            current += 1;
        } else if (strcmp("-H", argv[current]) == 0) {
            hLearn = atof(argv[current + 1]);
            current += 1;
        } else if (strcmp("-w", argv[current]) == 0) {
            wLearn = atof(argv[current + 1]);
            current += 1;
        } else if (strcmp("-n", argv[current]) == 0) {
            nHidden = atoi(argv[current + 1]);
            current += 1;
        } else if (strcmp("-i", argv[current]) == 0) {
            initialMomentum = atof(argv[current + 1]);
            current += 1;
        } else if (strcmp("-m", argv[current]) == 0) {
            finalMomentum = atof(argv[current + 1]);
            finalMomentumStart = atoi(argv[current + 2]);
            current += 2;
        } else if (strcmp("-e", argv[current]) == 0) {
            nEpochs = atoi(argv[current + 1]);
            current += 1;
        } else if (strcmp("-c", argv[current]) == 0) {
            wCost = atof(argv[current + 1]);
            current += 1;
        } else if (strcmp("-t", argv[current]) == 0) {
            do {
                current += 1;
                tIncrements.push_back(atoi(argv[current]));
            } while (current + 1 < argc &&
                     '0' <= argv[current + 1][0] && argv[current + 1][0] <= '9');
        } else if (strcmp("-s", argv[current]) == 0) {
            current += 1;
            seed = atoi(argv[current]);
        } else if (strcmp("--save", argv[current]) == 0) {
            current += 1;
            savePrefix = argv[current];
            do {
                current += 1;
                epochsToSaveAt.push_back(atoi(argv[current]));
            } while (current + 1 < argc &&
                     '0' <= argv[current + 1][0] && argv[current + 1][0] <= '9');
        } else if (strcmp("--never", argv[current]) == 0) {
            predictAlways = false;
        } else if (argc - current == 2) {
            trainFilename = argv[current];
        } else if (argc - current == 1) {
            testFilename = argv[current];
        } else {
            cerr << "ERROR: Unknown option: " << argv[current] << endl
                 << "Exiting now." << endl;
            exit(1);
        }
        current += 1;
    }
}

// Step 2: print the parameters
void printConfig() {
    cout << "Learning rates:" << endl
         << "  visible " << vLearn << endl
         << "  hidden  " << hLearn << endl
         << "  weights " << wLearn << endl
         << "  cost    " << wCost << endl;
    cout << "Hidden nodes: " << nHidden << endl;
    if (initialMomentum != 0.0 && finalMomentum != 0.0) {
        cout << "Momentum: " << endl
             << "  initial " << initialMomentum << endl
             << "  final   " << finalMomentum << endl
             << "  start   " << finalMomentumStart << endl;
    }
    if (tIncrements.size() > 0) {
        cout << "T-increment on: ";
        for (unsigned i = 0; i < tIncrements.size(); i++) {
            cout << tIncrements[i] << ' ';
        }
        cout << endl;
    }
    if (epochsToSaveAt.size() > 0) {
        cout << "Saving at epochs:";
        for (unsigned i = 0; i < epochsToSaveAt.size(); i++) {
            cout << epochsToSaveAt[i] << ' ';
        }
        cout << endl;
        cout << "Save prefix: " << savePrefix << endl;
    }
    cout << "Datasets: " << endl
         << "  train " << trainFilename << endl
         << "  test  " << testFilename << endl;
    cout << "Epochs: " << nEpochs << endl;
    cout << "Random seed: " << seed << endl;
    cout << "Predict always: " << (predictAlways ? "yes" : "no") << endl;
    cout << endl;
}

// Step 3: open the data file safely and hand it to loadTsv() (step 3.1, implemented in rbmdata.cpp / rbmpredictdata.cpp)
void safeLoad(RbmData& p, const string& fname) {
    ifstream f(fname.c_str());
    if (!f) {
        cerr << "ERROR: " << fname << " could not be opened." << endl;
        cerr << "Exiting now." << endl;
        exit(1);
    }
    p.loadTsv(f);
    f.close();
}

// NOTE: the <...> placeholders in this help text were also stripped in the original
// paste and have been reconstructed.
char helpString[] =
    "Usage: %s [arguments] <trainfile> <testfile>\n"
    "\n"
    "Arguments:\n"
    "  -h                  Print Help (this message) and exit\n"
    "  -v <float>          Set visible bias learning rate to <float>\n"
    "  -H <float>          Set hidden bias learning rate to <float>\n"
    "  -w <float>          Set weight learning rate to <float>\n"
    "  -c <float>          Set weight-cost coefficient to <float>\n"
    "  -n <int>            Use <int> hidden nodes\n"
    "  -i <float>          Set initial momentum to <float>\n"
    "  -m <float> <epoch>  Set final momentum to <float> at epoch <epoch>\n"
    "  -e <int>            Perform <int> epochs\n"
    "  -t <epoch> [<epoch> ...]\n"
    "                      Increase T by one at epoch <epoch>, ...\n"
    "  -d                  Print the defaults and exit\n"
    "  -s <int>            Set the random seed to <int>\n"
    "  --save <prefix> <epoch> [<epoch> ...]\n"
    "                      Save the predictions of the model after epoch <epoch>\n"
    "                      to file: <prefix> + <epoch> + .dat\n"
    "  --never             Do not predict, unless specified to save.\n"
    "\n"
    "Example usage:\n"
    "  ./rbm -t 5 5 7 train.dat test.dat\n"
    "    Set T to 3 at epoch 5, and to 4 at epoch 7.\n"
    "  ./rbm --never --save mypredfile 10 20 train.dat test.dat\n"
    "    Generates two files: mypredfile10.dat and mypredfile20.dat\n";

string trainFilename, testFilename, savePrefix;
float vLearn = 0.005, hLearn = 0.005, wLearn = 0.005, wCost = 0.005;
int nHidden = 100, nEpochs = 40;
float finalMomentum = 0.0, initialMomentum = 0.0;
int finalMomentumStart = 5;
deque<int> tIncrements;
deque<int> epochsToSaveAt;
int seed = 1;
bool predictAlways = true;
bool parallel = false;
int nThreads = 32;

int extractIncrements(int i) {
    int increment = 0;
    while (tIncrements.size() != 0 && tIncrements.front() == i) {
        increment += 1;
        tIncrements.pop_front();
    }
    return increment;
}
Now for the main part:
rbm.h
#ifndef RBM_H
#define RBM_H

#include "rbmdata.h"
#include "rbmpredictdata.h"
// The bracketed header names were lost in the original paste; these are the ones
// the declarations below rely on.
#include <Eigen/Dense>
#include <string>
#include <iostream>

using namespace Eigen;
using namespace std;

typedef Matrix<bool, 1, Dynamic> RowVectorXb;

class Rbm {
public:
    Rbm(const RbmData& data, int nHidden);

    virtual void performEpoch(const RbmData& data);

    virtual double predict(
        const RbmData& data,
        const RbmPredictData& predictData,
        const string& filename);

    virtual double predict(
        const RbmData& data,
        const RbmPredictData& predictData);

    virtual double predict(
        const RbmData& data,
        const RbmPredictData& predictData,
        ostream& predictStream);

    static float hBiasLearnRate;
    static float vBiasLearnRate;
    static float WlearnRate;
    static float weightDecay;
    static float momentum;
    static int T;

private:
    Rbm();

    void negActivation(const RowVectorXf& h0states);

    void gibbsSample();

    void normalizedNegActivation(const RowVectorXf& h0states);
    void softmax(const RowVectorXf& h0states);

    void initVisibleBias(const RbmData& data);

    void applyMomentum(const RbmData& data, int user);
    void selectWeights(const RbmData& data, int rangeStart, int rangeEnd);

public: // These are public so RbmParallel can use them
    MatrixXf W, Wsel, Wmomentum;
    RowVectorXf vBias, vBiasSel, vBiasMomentum, hBias, hBiasMomentum;

    void performEpoch(const RbmData& data, int userStart, int userEnd);

private:
    RowVectorXf h0probs, hTProbs;
    RowVectorXb h0states;

    RowVectorXf negData;

    MatrixXf posProds, negProds;

    int nHidden;
    int nClasses;
};

#endif
rbm.cpp
#include "rbm.h"
// Header names below were stripped in the original paste; reconstructed from what the code uses.
#include <vector>
#include <algorithm>
#include <fstream>
#include <sstream>
#include <cmath>

float Rbm::hBiasLearnRate = 0.001;
float Rbm::vBiasLearnRate = 0.008;
float Rbm::WlearnRate = 0.0006;
float Rbm::weightDecay = 0.0001;
float Rbm::momentum = 0.5;
int Rbm::T = 1;


// Step 4: initialize W, hBias, vBias
// W  : (nMovies * K) rows x M columns, K = number of rating classes, M = number of hidden nodes
// vb : nMovies * K
// hb : M
Rbm::Rbm(const RbmData& data, int nHidden)
    : nHidden(nHidden), nClasses(data.nClasses)
{
    W.setRandom(data.nMovies * nClasses, nHidden);
    W.array() *= 0.01;
    vBias.setZero(data.nMovies * nClasses);
    initVisibleBias(data);      // initialize the visible-layer biases from the data
    hBias.setZero(nHidden);
}


// Step 4.1: accumulate each user's movies/ratings slices to initialize the visible biases
void Rbm::initVisibleBias(const RbmData& data) {
    MatrixXi totals = MatrixXi::Zero(1, data.nMovies);
    int nUsers = data.range.size() - 1;
    for (int i = 0; i < nUsers; i++) {

        // segment() takes the slice [data.range(i), data.range(i + 1)), i.e. user i's entries
        const auto& movies = data.movies.segment(
            data.range(i), data.range(i + 1) - data.range(i));
        const auto& ratings = data.ratings.segment(
            data.range(i), data.range(i + 1) - data.range(i));

        // number of entries user i has; equals the length of this user's movies/ratings slices
        int amountOfRatings = data.range(i + 1) - data.range(i);
        for (int r = 0; r < amountOfRatings; r++) {
            vBias(movies(r)) += ratings(r);
            int exactMovie = movies(r) / nClasses;
            totals(exactMovie) += ratings(r);
        }
    }
    for (int m = 0; m < totals.size(); m++) {
        for (int c = 0; c < nClasses; c++) {
            if (vBias(m * nClasses + c) != 0) {
                vBias(m * nClasses + c) /= totals(m);
                vBias(m * nClasses + c) =
                    log(vBias(m * nClasses + c));
            }
        }
    }
}


// Step 5: just forwards to the per-user-range overload below
void Rbm::performEpoch(const RbmData& data) {
    performEpoch(data, 0, data.range.size() - 1);
}

// Step 5.1: the real work
void Rbm::performEpoch(const RbmData& data, int userStart, int userEnd) {
    vector<int> randomizedIds(userEnd - userStart, 0);   // (userEnd - userStart) zero-initialized entries
    for (int i = userStart; i < userEnd; i++)
        randomizedIds[i - userStart] = i;                // fill with the user ids in this range
    random_shuffle(randomizedIds.begin(), randomizedIds.end());   // visit users in random order

    for (unsigned i = 0; i < randomizedIds.size(); i++) {
        int rangeStart = data.range(randomizedIds[i]);
        int rangeEnd = data.range(randomizedIds[i] + 1); // end is exclusive
        int rangeLength = rangeEnd - rangeStart;         // number of visible units for this user (rated movies * nClasses)

        selectWeights(data, rangeStart, rangeEnd);       // gather the weight rows matching this user's movies
        const auto& visData = data.ratings.segment(rangeStart, rangeLength);   // this user's visible (rating) vector

        h0probs = 1 / (1 + (-visData*Wsel - hBias).array().exp());   // h0: hidden probabilities given the data
        h0states = h0probs.array() >
            (h0probs.Random(h0probs.size()).array() + 1) / 2;   // sample hidden states; Random() is uniform in [-1,1], so (r+1)/2 is in [0,1]
        negActivation(h0states.cast<float>());                  // reconstruct the visible units (softmax per movie)

        hTProbs = 1 / (1 + (-negData*Wsel - hBias).array().exp());   // hT: hidden probabilities given the reconstruction
        for (int t = 1; t < T; t++)
            gibbsSample();                                       // extra Gibbs steps for CD-T

        posProds.noalias() = visData.transpose() * h0probs;
        negProds.noalias() = negData.transpose() * hTProbs;

        if (i > 0)
            applyMomentum(data, randomizedIds[i - 1]);   // apply momentum using the previous user's deltas

        // parameter updates
        hBiasMomentum.noalias() = hBiasLearnRate * (h0probs - hTProbs);
        hBias.noalias() += hBiasMomentum;
        vBiasMomentum.noalias() = vBiasLearnRate * (visData - negData);
        for (int r = rangeStart; r < rangeEnd; r++)
            vBias(data.movies(r)) += vBiasMomentum(r - rangeStart);

        Wmomentum.noalias() = posProds - negProds;
        for (int r = 0; r < rangeLength; r++)
            W.row(data.movies(r + rangeStart)).noalias() +=
                WlearnRate * (Wmomentum.row(r) - weightDecay*Wsel.row(r));
    }
}

void Rbm::negActivation(const RowVectorXf& h0states) {
    softmax(h0states);
}

void Rbm::gibbsSample() {
    h0states =
        hTProbs.array() > (hTProbs.Random(hTProbs.size()).array() + 1) / 2;
    negActivation(h0states.cast<float>());
    hTProbs = 1 / (1 + (-negData*Wsel - hBias).array().exp());
}

void Rbm::normalizedNegActivation(const RowVectorXf& h0states) {
    softmax(h0states);
}

void Rbm::softmax(const RowVectorXf& h0states) {
    negData = (h0states*Wsel.transpose() + vBiasSel);
    for (int m = 0; m < negData.size(); m += nClasses) {
        negData.segment(m, nClasses).array() -=
            negData.segment(m, nClasses).maxCoeff();   // subtract the per-movie max (numerical stability for exp)
    }
    negData.array() = negData.array().exp();
    for (int m = 0; m < negData.size(); m += nClasses) {
        negData.segment(m, nClasses).array() /=
            negData.segment(m, nClasses).sum();        // normalize each movie's K entries to sum to 1: the softmax
    }
}

void Rbm::selectWeights(const RbmData& data, int rangeStart, int rangeEnd) {
    int rangeLength = rangeEnd - rangeStart;
    Wsel.resize(rangeLength, nHidden);
    vBiasSel.resize(rangeLength);
    for (int r = rangeStart; r < rangeEnd; r++) {
        Wsel.row(r - rangeStart).noalias() = W.row(data.movies(r));
        vBiasSel(r - rangeStart) = vBias(data.movies(r));
    }
}

void Rbm::applyMomentum(const RbmData& data, int user) {
    if (momentum == 0.0) return;
    int rangeStart = data.range(user);
    int rangeEnd = data.range(user + 1); // end is exclusive
    int rangeLength = rangeEnd - rangeStart;

    hBias.noalias() += momentum * hBiasMomentum;
    for (int r = rangeStart; r < rangeEnd; r++)
        vBias(data.movies(r)) += momentum * vBiasMomentum(r - rangeStart);

    for (int r = 0; r < rangeLength; r++) {
        W.row(data.movies(r + rangeStart)).noalias() +=
            momentum * Wmomentum.row(r);
    }
}

double Rbm::predict(const RbmData& data, const RbmPredictData& predictData,
                    const string& fname) {
    ofstream out(fname.c_str());
    double d = predict(data, predictData, out);
    out.close();
    return d;
}

// Step 6: forwards to the overload below
double Rbm::predict(const RbmData& data, const RbmPredictData& predictData) {
    stringstream dontcare;   // throwaway stream: the predictions are discarded, only the RMSE is returned
    return predict(data, predictData, dontcare);
}

// Step 6.1: reconstruct ratings from the trained model and compute the RMSE against the held-out data
double Rbm::predict(const RbmData& data, const RbmPredictData& predictData,
                    ostream& predictStream) {
    double rmse = 0.0;
    int predictCount = 0;
    for (unsigned i = 0; i < predictData.userIds.size(); i++) {
        int userId = predictData.userIds[i];
        int rangeStart = data.range(userId);
        int rangeEnd = data.range(userId + 1); // end is exclusive
        selectWeights(data, rangeStart, rangeEnd);
        const auto& visData =
            data.ratings.segment(rangeStart, rangeEnd - rangeStart);
        h0probs = 1 / (1 + (-visData*Wsel - hBias).array().exp());   // hidden probabilities from the user's training ratings

        rangeStart = predictData.range(i);
        rangeEnd = predictData.range(i + 1);
        selectWeights(predictData, rangeStart, rangeEnd);
        normalizedNegActivation(h0probs);
        const auto& actualData =
            predictData.ratings.segment(rangeStart, rangeEnd - rangeStart);   // the held-out (test) ratings for this user

        for (int r = 0; r < actualData.size(); r += nClasses) {
            float actual = 0;
            float predicted = 0;
            for (int c = 0; c < nClasses; c++) {
                actual += actualData(r + c) * (c + 1);   // convert the K-dimensional one-hot encoding back to a rating in 1..K
                predicted += negData(r + c) * (c + 1);
            }
            predictStream << predicted << endl;
            float t = actual - predicted;
            rmse += t * t;
        }
        predictCount += actualData.size() / nClasses;
    }
    return sqrt(rmse / predictCount);
}
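For my own notes, the epoch above appears to implement plain CD-T for an RBM with K-way softmax visible units per rated movie, much like the RBM for collaborative filtering of Salakhutdinov, Mnih and Hinton (2007); this is my reading of the code, not documentation from its author. In that notation, for a user with rated movies m and rating classes k = 1..K:

\[
p(h_j = 1 \mid v) = \sigma\Big(b_j + \sum_{m,k} v_m^k\, W_{mk,j}\Big), \qquad \sigma(x) = \frac{1}{1 + e^{-x}}
\]
\[
p(v_m^k = 1 \mid h) = \frac{\exp\big(a_m^k + \sum_j h_j W_{mk,j}\big)}{\sum_{l=1}^{K} \exp\big(a_m^l + \sum_j h_j W_{ml,j}\big)}
\]
\[
\Delta W_{mk,j} \approx \varepsilon_W \big(\langle v_m^k h_j \rangle_{\mathrm{data}} - \langle v_m^k h_j \rangle_{T} - \lambda\, W_{mk,j}\big)
\]

h0probs and hTProbs compute the first line, softmax() computes the second (the per-movie max subtraction only keeps exp() from overflowing), and the W, vBias and hBias updates at the end of performEpoch() are the third line plus the momentum term added in applyMomentum().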
rbmparallel.h
#ifndef RBM_PARALLEL_H
#define RBM_PARALLEL_H

#include "rbm.h"
#include "rbmdata.h"
#include <vector>   // the template arguments below were lost in the original paste
#include <thread>

class RbmParallel : public Rbm {
public:
    RbmParallel(int nThreads, const RbmData& data, int nHidden);

    void performEpoch(const RbmData& data);

private:
    RbmParallel();

    void startEpochs(const RbmData& data, int batchStart, int batchSize);
    void joinExecution();
    void updateWeights();
    void synchronizeWeights();

    static void performEpochInThread(
        Rbm& r, const RbmData& data, int userStart, int userEnd);

    int nThreads;
    int nUsers;
    vector<Rbm> subRbms;       // one extra Rbm per worker thread
    vector<thread> threads;
};

#endif
rbmparallel.cpp
#include "rbmparallel.h"
#include <functional>   // std::ref (not shown in the original paste, but needed below)

RbmParallel::RbmParallel(int nThreads, const RbmData& data, int nHidden) :
    Rbm(data, nHidden),
    nThreads(nThreads),
    nUsers(data.range.size() - 1),
    subRbms(nThreads - 1, Rbm(data, nHidden))
{
    synchronizeWeights();
}

void RbmParallel::performEpoch(const RbmData& data) {
    int batchSize = nUsers;   // currently a single batch that covers all users
    for (int batchStart = 0; batchStart < nUsers; batchStart += batchSize) {
        startEpochs(data, batchStart, batchSize);
        joinExecution();
        updateWeights();
    }
}

void RbmParallel::startEpochs(
        const RbmData& data, int batchStart, int batchSize) {
    threads.clear();
    int usersPerStep = batchSize / nThreads;
    int userStart = batchStart;
    for (int i = 0; i < nThreads - 1; i++) {
        threads.push_back(
            thread(&RbmParallel::performEpochInThread,
                   ref(subRbms[i]), ref(data),
                   userStart, userStart + usersPerStep));
        userStart += usersPerStep;
    }
    // the calling thread trains on the remaining users itself
    Rbm::performEpoch(data, userStart, nUsers);
}

void RbmParallel::performEpochInThread(
        Rbm& r, const RbmData& data, int userStart, int userEnd) {
    r.performEpoch(data, userStart, userEnd);
}

void RbmParallel::joinExecution() {
    for (auto it = threads.begin(); it != threads.end(); it++) {
        it->join();
    }
}

void RbmParallel::updateWeights() {
    float updateFactor = 1.0 / nThreads;
    W *= updateFactor;
    vBias *= updateFactor;
    hBias *= updateFactor;
    for (auto it = subRbms.begin(); it != subRbms.end(); it++) {
        W += it->W * updateFactor;
        vBias += it->vBias * updateFactor;
        hBias += it->hBias * updateFactor;
    }
    synchronizeWeights();
}

void RbmParallel::synchronizeWeights() {
    for (auto it = subRbms.begin(); it != subRbms.end(); it++) {
        it->W = W;
        it->hBias = hBias;
        it->vBias = vBias;
    }
}
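As I read it, the parallel version is straightforward model averaging: the main Rbm plus the nThreads - 1 subRbms all start each epoch from the same synchronized weights, each trains on its own contiguous slice of users, and updateWeights() then averages the results and broadcasts them back via synchronizeWeights():

\[
W \leftarrow \frac{1}{n}\Big(W^{(\mathrm{main})} + \sum_{k=1}^{n-1} W^{(k)}\Big), \qquad n = \mathrm{nThreads},
\]

and likewise for vBias and hBias. Since performEpoch() currently uses a single batch covering all users, this averaging happens once per epoch.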
My fundamentals are still too weak: after a week of reading I still don't fully understand it.