项目中需要求解动态增长的线性方程组:方程个数会不断增加,需要使用高斯消元法来求解;每添加一个方程后,都要求出当前已经能够确定的变量值。
代码分为三个类:表达式 expression、方程 equation 和高斯消元求解器 gauss_eliminator。
这段代码还是有问题,如果需要最新的,请访问我的github https://github.com/tfjiang/gauss_elimination
#ifndef GAUSS_ELIMINATION_H
#define GAUSS_ELIMINATION_H

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <deque>
#include <iostream>
#include <list>
#include <map>
#include <vector>

#include <boost/mpl/assert.hpp>
#include <boost/static_assert.hpp>
#include <boost/type_traits/is_same.hpp>
#include <boost/unordered_map.hpp>

#include <zjucad/matrix/matrix.h> // a private matrix library

namespace jtf{
namespace algorithm{
/**
 * @brief One term  coefficient * X_index  of a linear equation.
 */
template <typename T>
class expression{
public:
  //! default term: index size_t(-1) (no valid variable) and coefficient 0
  expression():index(static_cast<size_t>(-1)),coefficient(0){}
  expression(const size_t & index_, const T & coefficient_)
    : index(index_), coefficient(coefficient_){}

  size_t index;    //!< index of the variable this term multiplies
  T coefficient;   //!< coefficient of the variable

  /**
   * @brief order terms by variable index only (used to sort an equation)
   *
   * @param b other expression
   * @return bool true if this term's index is smaller than b's
   */
  bool operator < (const expression<T> & b) const{
    return index < b.index;
  }

  /**
   * @brief whether the coefficient of this term is numerically zero
   *
   * @param eps absolute tolerance; defaults to the 1e-6 used by the rest of
   *            this header
   * @return bool true if |coefficient| < eps
   */
  bool is_zero(const double eps = 1e-6) const{
    return std::fabs(coefficient) < eps;
  }
};
48 /** 49 * @brief make expression with node_idx and value 50 * 51 * @param node_idx 52 * @param value 53 * @return expression<T> 54 */
55 template <typename T>
56 expression<T> make_expression(const size_t & node_idx, const T & value) 57 { 58 expression<T> temp(node_idx, value); 59 return temp; 60 } 61
62 /** 63 * @brief equation class 64 * 65 */
66 template <typename T>
67 class equation{ 68 public: 69 typedef typename std::list<expression<T> >::const_iterator eq_const_iterator; 70 typedef typename std::list<expression<T> >::iterator eq_iterator; 71
72 equation():value_(static_cast<T>(0)){} 73
74 eq_const_iterator begin() const {return e_vec_.begin();} 75 eq_iterator begin(){return e_vec_.begin();} 76
77 eq_const_iterator end() const {return e_vec_.end();} 78 eq_iterator end(){return e_vec_.end();} 79
80 /** 81 * @brief used to standardizate the equation, sort the expressions and merge 82 * similar items and normalization to make the first item's coefficient 83 * equal 1 84 * 85 * @return int 86 */
87 int standardization(){ 88 sort_equation(); 89 merge_similar(); 90 normalization(); 91 return 0; 92 } 93
94 /** 95 * @brief update the equation with given node idx and value 96 * 97 * @param node_idx input node idx 98 * @param node_value input node value 99 * @return int 100 */
101 int update(const size_t & node_idx, const T & node_value); 102 /** 103 * @brief sort equation accronding to the expressions 104 * 105 * @return int 106 */
107 int sort_equation(){ 108 e_vec_.sort(); 109 return 0; 110 } 111
112 /** 113 * @brief merge similar items, WARNING. this function should be used after 114 * sorting. 115 * 116 * @return int 117 */
118 int merge_similar(); 119
120 /** 121 * @brief normalize the equations: to make the first expression's coefficient 122 * equals 1 123 * 124 * @return int 125 */
126 int normalization(); 127
128 /** 129 * @brief get the equation value 130 * 131 * @return const T 132 */
133 const T& get_value() const {return value_;} 134
135 /** 136 * @brief get the state of equation: 137 * if there are no expressions, 138 * if value != 0 return -1; error equation 139 * else return 0; cleared 140 * else 141 * if there is one expression return 1; calculated finial node 142 * else return 2; not finished 143 * 144 * @return int 145 */
146 int state() const; 147
148
149 /** 150 * @brief define the minus operation 151 * 152 * @param b 153 * @return equation<T> 154 */
155 equation<T> & operator -= (const equation<T> & b); 156
157 /** 158 * @brief define the output operation 159 * 160 * @param eq 161 * @return equation<T> 162 */
163 friend std::ostream& operator << (std::ostream &output, 164 const equation<T> &eq) 165 { 166 if(eq.e_vec_.empty()){ 167 output << "# ------ empty expressions with value = " << eq.value() << std::endl; 168 }else{ 169 output << "# ------ SUM coeff * index , val" << std::endl 170 << "# ------ "; 171 for(equation<T>::eq_const_iterator eqcit = eq.begin(); eqcit != eq.end();){ 172 const expression<T> & exp = *eqcit; 173 output << exp.coefficient << "*X" << exp.index; 174 ++eqcit; 175 if(eqcit == eq.end()) 176 output << " = "; 177 else
178 output << " + "; 179 } 180 output << eq.value() << std::endl; 181 } 182 return output; 183 } 184
185
186 /** 187 * @brief get the first expression idx 188 * 189 * @return size_t 190 */
191 size_t get_prim_idx() const { 192 return e_vec_.front().index; 193 } 194
195 int add_expression(const expression<T> & exp){ 196 e_vec_.push_back(exp); 197 return 0; 198 } 199 T& value() {return value_;} 200 const T& value() const {return value_;} 201 std::list<expression<T> > e_vec_; 202 private: 203 T value_; 204 }; 205
206 template <typename T>
207 int equation<T>::merge_similar(){ 208 typedef typename std::list<expression<T> >::iterator leit; 209 leit current = e_vec_.begin(); 210 leit next = current; 211 ++next; 212 while(next != e_vec_.end()){ 213 if(next->index > current->index) { 214 current = next++; 215 continue; 216 }else if(next->index == current->index){ 217 current->coefficient += next->coefficient; 218 e_vec_.erase(next++); 219 if(fabs(current->coefficient) < 1e-6){ 220 e_vec_.erase(current++); 221 ++next; 222 } 223 }else{ 224 std::cerr << "# [error] merge similar function should only be called "
225 << "after sorting." << std::endl; 226 return __LINE__; 227 } 228 } 229 return 0; 230 } 231
232 template <typename T>
233 int equation<T>::normalization() 234 { 235 const T coeff = e_vec_.front().coefficient; 236 if(fabs(coeff) < 1e-6){ 237 if(e_vec_.empty()){ 238 //std::cerr << "# [info] this equation is empty." << std::endl;
239 return 0; 240 }else
241 std::cerr << "# [error] this expression should be removed." << std::endl; 242 return __LINE__; 243 } 244
245 value() /= coeff; 246 for(typename std::list<expression<T> >::iterator lit = e_vec_.begin(); 247 lit != e_vec_.end(); ++lit){ 248 expression<T> & ep = *lit; 249 ep.coefficient /= coeff; 250 } 251 return 0; 252 } 253
254 template <typename T>
255 int equation<T>::state() const
256 { 257 if(e_vec_.empty()){ 258 if(fabs(value()) < 1e-8) 259 return 0; // is cleared
260 return -1; // is conflicted
261 }else{ 262 if(e_vec_.size() == 1) 263 return 1; // finial variant
264 else
265 return 2; // not ready
266 } 267 } 268
269 template <typename T>
270 equation<T> & equation<T>::operator -= (const equation<T> & b ) 271 { 272 if(&b == this) { 273 e_vec_.clear(); 274 value() = 0; 275 return *this; 276 } 277
278 for(typename std::list<expression<T> >::const_iterator lecit_b =
279 b.e_vec_.begin(); lecit_b != b.e_vec_.end(); ++lecit_b){ 280 const expression<T> & exp = *lecit_b; 281 const size_t &node_idx = exp.index; 282 assert(fabs(exp.coefficient) > 1e-6); 283 bool is_found = false; 284
285 for(typename std::list<expression<T> >::iterator leit_a = e_vec_.begin(); 286 leit_a != e_vec_.end(); ++leit_a){ 287 expression<T> & exp_a = *leit_a; 288 if(exp_a.index == node_idx){ 289 exp_a.coefficient -= exp.coefficient; 290 // zeros
291 if(fabs(exp_a.coefficient) < 1e-6) { 292 e_vec_.erase(leit_a); 293 is_found = true; 294 break; 295 } 296 } 297 } 298 if(!is_found){ 299 e_vec_.push_back(make_expression(node_idx, -1 * exp.coefficient)); 300 } 301 } 302
303 value() -= b.value(); 304 sort_equation(); 305 normalization(); 306
307 return *this; 308 } 309
310 template <typename T>
311 int equation<T>::update(const size_t & node_idx, const T & node_value) 312 { 313 for(typename std::list<expression<T> >::iterator leit = e_vec_.begin(); 314 leit != e_vec_.end();){ 315 expression<T> & exp = *leit; 316 if(exp.index == node_idx){ 317 value() -= exp.coefficient * node_value; 318 e_vec_.erase(leit++); 319 } 320 ++leit; 321 } 322 return 0; 323 } 324
325 //! @brief this class only handle Ai+Bi=Ci
326 template <typename T>
327 class gauss_eliminator{ 328 BOOST_MPL_ASSERT_MSG((boost::is_same<T,double>::value ) ||
329 (boost::is_same<T,float>::value ), 330 NON_FLOAT_TYPES_ARE_NOT_ALLOWED, (void)); 331 public: 332 /** 333 * @brief construct gauss_eliminator class 334 * 335 * @param nodes input nodes 336 * @param node_flag input node_flag which will be tagged as true if the 337 * corresponding node is known 338 */
339 gauss_eliminator(zjucad::matrix::matrix<T> & nodes, 340 std::vector<bool> & node_flag) 341 :nodes_(nodes), node_flag_(node_flag){ 342 idx2equation_.resize(nodes_.size()); 343 } 344
345 /** 346 * @brief add equation to gauss_eliminator, every time an equation is added, 347 * eliminate function is called. 348 * 349 * @param input equation 350 * @return int 351 */
352 int add_equation(const equation<T> & e); 353
354 /** 355 * @brief This function will start to eliminate equations above all added equations 356 * 357 * @return int return 0 if works fine, or return non-zeros 358 */
359 int eliminate(); 360
361
362 /** 363 * @brief update the equation, it will check all variant, if a variant is 364 * already known, update this equation 365 * 366 * @param eq input equation 367 * @return int return 0 if nothing changes, or retunr 1; 368 */
369 int update_equation(equation<T> & eq); 370
371 private: 372 zjucad::matrix::matrix<T> & nodes_; 373 std::vector<bool> & node_flag_; 374 std::list<equation<T> > es; 375
376 typedef typename std::list<equation<T> >::iterator equation_ptr; 377 std::vector<std::list<equation_ptr> > idx2equation_; 378
379 // this map store the smallest expression
380 typedef typename std::map<size_t, std::list<equation_ptr> >::iterator prime_eq_ptr; 381 std::map<size_t, std::list<equation_ptr> > prime_idx2equation_; 382 }; 383
384 template <typename T>
385 int gauss_eliminator<T>::add_equation(const equation<T> & e){ 386 es.push_back(e); 387 equation<T> & e_back = es.back(); 388 for(typename equation<T>::eq_iterator eit = e_back.begin(); eit != e_back.end(); ){ 389 if(node_flag_[eit->index]){ 390 e_back.value() -= nodes_[eit->index] * eit->coefficient; 391 e_back.e_vec_.erase(eit++); 392 }else
393 ++eit; 394 } 395 if(e_back.state() == 0){// this equation is cleared
396 es.pop_back(); 397 return 0; 398 }else if(e_back.state() == -1){ 399 std::cerr << "# [error] strange conflict equation: " << std::endl; 400 std::cerr << e; 401 es.pop_back(); 402 return __LINE__; 403 } 404 e_back.standardization(); 405
406 for(typename equation<T>::eq_const_iterator it = e.begin(); 407 it != e.end(); ++it){ 408 equation_ptr end_ptr = es.end(); 409 idx2equation_[it->index].push_back(--end_ptr); 410 } 411
412 equation_ptr end_ptr = es.end(); 413 prime_idx2equation_[e_back.get_prim_idx()].push_back(--end_ptr); 414 eliminate(); 415 return 0; 416 } 417
418 template <typename T>
419 int gauss_eliminator<T>::update_equation(equation<T> & eq) 420 { 421 for(typename std::list<expression<T> >::iterator it = eq.e_vec_.begin(); 422 it != eq.e_vec_.end(); ){ 423 const expression<T> & exp = *it; 424 if(node_flag_[exp.index]){ 425 eq.value() -= exp.coefficient * nodes_[exp.index]; 426 eq.e_vec_.erase(it++); 427 }else
428 ++it; 429 } 430 eq.standardization(); 431 return 0; 432 } 433
434 template <typename T>
435 int gauss_eliminator<T>::eliminate() 436 { 437 std::cerr << std::endl; 438 while(1){ 439 bool is_modified = false; 440
441 for(prime_eq_ptr ptr = prime_idx2equation_.begin(); 442 ptr != prime_idx2equation_.end();) 443 { 444 std::list<equation_ptr> & dle = ptr->second; 445 if(dle.empty()) { 446 prime_idx2equation_.erase(ptr++); 447 is_modified = true; 448 continue; 449 }else if(dle.size() == 1){ // contain only one equation
450 const int state_ = dle.front()->state(); 451 if(state_ == 0) { // cleared
452 es.erase(dle.front()); 453 continue; 454 }else if(state_ == -1){ // conflict equation
455 std::cerr << "# [error] conflict equation " << std::endl; 456 return __LINE__; 457 }else if(state_ == 1){ // finial variant
458 const equation<T> & eq = *dle.front(); 459 const T &value_ = eq.value(); 460 // prime index's coefficient should be 1
461 assert(fabs(eq.e_vec_.front().coefficient - 1) < 1e-6); 462 const size_t index = eq.get_prim_idx(); 463 if(node_flag_[index]){ 464 if(fabs(nodes_[index] - value_) > 1e-6){ 465 std::cerr << "# [error] conficts happen, node " << index 466 << " has different value " << nodes_[index] << ","
467 << value_ << std::endl; 468 return __LINE__; 469 } 470 }else{ 471 // update corresponding equations with the new node value
472 nodes_[index] = value_; 473 node_flag_[index] = true; 474 std::list<equation_ptr> & node_linked_eq = idx2equation_[index]; 475 for(typename std::list<equation_ptr>::iterator leqit =
476 node_linked_eq.begin(); leqit != node_linked_eq.end(); ++leqit){ 477 equation<T> & eq = *(*leqit); 478 eq.update(index, nodes_[index]); 479 eq.standardization(); 480 } 481 node_linked_eq.clear(); 482 prime_idx2equation_.erase(ptr); 483 is_modified = true; 484 } 485 } 486 ++ptr; 487 }else{ 488 assert(dle.size() > 1); 489 // this prime_index point to several equations, 490 // which sould be eliminated
491 typename std::list<equation_ptr>::iterator begin = dle.begin(); 492 typename std::list<equation_ptr>::iterator first = begin++; 493 // to keep each prime index linked only one equation
494 for(typename std::list<equation_ptr>::iterator next = begin; 495 next != dle.end();){ 496 // to eliminate the prime index, each equation minus the first one
497 *(*next) -= *(*first); 498 (*next)->standardization(); 499 const size_t prim_index = (*next)->get_prim_idx(); 500 assert(prim_index >= (*first)->get_prim_idx()); 501 prime_idx2equation_[prim_index].push_back(*next); 502 dle.erase(next++); 503 } 504 ++ptr; 505 } // end else
506 } 507 if(!is_modified) break; 508 } 509 return 0; 510 } 511 } 512 } 513 #endif // GAUSS_ELIMINATION_H
下面有个简单的测试例子:
//2x2+2x1+2x1=2
//x1=1
//5x1+x2=2
1 int main() 2 { typedef double val_type; 3 matrix<val_type> node = zeros<val_type>(4,1); 4 vector<bool> node_flag(4,false); 5
6 jtf::algorithm::gauss_eliminator<val_type> ge(node, node_flag); 7
8 { 9 equation<val_type> eq; 10 eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(1), 11 static_cast<val_type>(2))); 12 eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0), 13 static_cast<val_type>(2))); 14 eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0), 15 static_cast<val_type>(2))); 16 eq.value() = 2; 17 ge.add_equation(eq); 18 } 19 { 20 equation<val_type> eq; 21 eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0), 22 static_cast<val_type>(1))); 23 eq.value() = 1; 24 ge.add_equation(eq); 25 } 26 { 27 equation<val_type> eq; 28 eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0), 29 static_cast<val_type>(5))); 30
31 eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(1), 32 static_cast<val_type>(1))); 33
34 eq.value() = 2; 35 ge.add_equation(eq); 36 } 37 for(size_t t = 0; t < node_flag.size(); ++t){ 38 if(node_flag[t] == true){ 39 cerr << "# node " << t << " = " << node[t] << endl; 40 }else
41 cerr << "# node " << t << " unknown." << endl; 42 } 43 return 0; 44 }
这段code在g++4.6.3上编译通过,如果有更加方便简单的方法,请大家指教~