高斯消元法的简单实现

项目中需要求解一个动态增长的线性方程组:方程的个数会不断增加,因此采用高斯消元法增量求解——每添加一个方程,就立即求出当前已经能够确定的变量值。

代码分为三个类:表达式 expression、方程 equation 和高斯消元求解器 gauss_eliminator。

这段代码仍有问题;如需最新版本,请访问我的 GitHub:https://github.com/tfjiang/gauss_elimination

 #ifndef GAUSS_ELIMINATION_H
#define GAUSS_ELIMINATION_H

#include <cassert>
#include <cmath>
#include <cstddef>
#include <deque>
#include <iostream>
#include <list>
#include <map>
#include <vector>

#include <boost/mpl/assert.hpp>
#include <boost/static_assert.hpp>
#include <boost/type_traits/is_same.hpp>
#include <boost/unordered_map.hpp>

#include <zjucad/matrix/matrix.h>  // a private matrix library

 15 namespace jtf{  16 namespace algorithm{  17 
 18 /**  19  * @brief store an expression A * i, where A is the coeffieient  20  *  21  */
 22 template <typename T>
 23 class expression{  24 public:  25   expression():index(-1),coefficient(0){}  26   expression(const size_t & index_, const T & coefficient_)  27  : index(index_), coefficient(coefficient_){}  28  size_t index;  29  T coefficient;  30   /**  31  * @brief operator < is used to determin which expression ha  32  *  33  * @param other input other expression  34  * @return bool  35    */
 36   bool operator < (const expression<T> & b) const{  37     return index < b.index;  38  }  39 
 40   /**  41  * @brief To determin whether coefficent of this expression is zeros  42  *  43  * @return bool  44    */
 45   bool is_zero() const{ return fabs(coefficient) < 1e-6;}  46 };  47 
 48 /**  49  * @brief make expression with node_idx and value  50  *  51  * @param node_idx  52  * @param value  53  * @return expression<T>  54  */
 55 template <typename T>
 56 expression<T> make_expression(const size_t & node_idx, const T & value)  57 {  58   expression<T> temp(node_idx, value);  59   return temp;  60 }  61 
 62 /**  63  * @brief equation class  64  *  65  */
 66 template <typename T>
 67 class equation{  68 public:  69   typedef typename std::list<expression<T> >::const_iterator eq_const_iterator;  70   typedef typename std::list<expression<T> >::iterator eq_iterator;  71 
 72   equation():value_(static_cast<T>(0)){}  73 
 74   eq_const_iterator begin() const {return e_vec_.begin();}  75   eq_iterator begin(){return e_vec_.begin();}  76 
 77   eq_const_iterator end() const {return e_vec_.end();}  78   eq_iterator end(){return e_vec_.end();}  79 
 80   /**  81  * @brief used to standardizate the equation, sort the expressions and merge  82  * similar items and normalization to make the first item's coefficient  83  * equal 1  84  *  85  * @return int  86    */
 87   int standardization(){  88  sort_equation();  89  merge_similar();  90  normalization();  91     return 0;  92  }  93 
 94   /**  95  * @brief update the equation with given node idx and value  96  *  97  * @param node_idx input node idx  98  * @param node_value input node value  99  * @return int 100    */
101   int update(const size_t & node_idx, const  T & node_value); 102   /** 103  * @brief sort equation accronding to the expressions 104  * 105  * @return int 106    */
107   int sort_equation(){ 108  e_vec_.sort(); 109     return 0; 110  } 111 
112   /** 113  * @brief merge similar items, WARNING. this function should be used after 114  * sorting. 115  * 116  * @return int 117    */
118   int merge_similar(); 119 
120   /** 121  * @brief normalize the equations: to make the first expression's coefficient 122  * equals 1 123  * 124  * @return int 125    */
126   int normalization(); 127 
128   /** 129  * @brief get the equation value 130  * 131  * @return const T 132    */
133   const T& get_value() const {return value_;} 134 
135   /** 136  * @brief get the state of equation: 137  * if there are no expressions, 138  * if value != 0 return -1; error equation 139  * else return 0; cleared 140  * else 141  * if there is one expression return 1; calculated finial node 142  * else return 2; not finished 143  * 144  * @return int 145    */
146   int state() const; 147 
148 
149   /** 150  * @brief define the minus operation 151  * 152  * @param b 153  * @return equation<T> 154    */
155   equation<T> & operator -= (const equation<T> & b); 156 
157   /** 158  * @brief define the output operation 159  * 160  * @param eq 161  * @return equation<T> 162      */
163   friend std::ostream& operator << (std::ostream &output, 164                              const equation<T> &eq) 165  { 166     if(eq.e_vec_.empty()){ 167       output << "# ------ empty expressions with value = " << eq.value() << std::endl; 168     }else{ 169       output << "# ------ SUM coeff * index , val" << std::endl 170            << "# ------ "; 171       for(equation<T>::eq_const_iterator eqcit = eq.begin(); eqcit != eq.end();){ 172         const expression<T> & exp = *eqcit; 173         output << exp.coefficient << "*X" << exp.index; 174         ++eqcit; 175         if(eqcit == eq.end()) 176           output << " = "; 177         else
178           output << " + "; 179  } 180       output << eq.value() << std::endl; 181  } 182     return output; 183  } 184 
185 
186   /** 187  * @brief get the first expression idx 188  * 189  * @return size_t 190    */
191   size_t get_prim_idx() const { 192     return e_vec_.front().index; 193  } 194 
195   int add_expression(const expression<T> & exp){ 196  e_vec_.push_back(exp); 197     return 0; 198  } 199   T& value() {return value_;} 200   const T& value() const {return value_;} 201   std::list<expression<T> > e_vec_; 202 private: 203  T value_; 204 }; 205 
206 template <typename T>
207 int equation<T>::merge_similar(){ 208   typedef typename std::list<expression<T> >::iterator leit; 209   leit current = e_vec_.begin(); 210   leit next = current; 211   ++next; 212   while(next != e_vec_.end()){ 213     if(next->index > current->index) { 214       current = next++; 215       continue; 216     }else if(next->index == current->index){ 217       current->coefficient += next->coefficient; 218       e_vec_.erase(next++); 219       if(fabs(current->coefficient) < 1e-6){ 220         e_vec_.erase(current++); 221         ++next; 222  } 223     }else{ 224       std::cerr << "# [error] merge similar function should only be called "
225                 << "after sorting." << std::endl; 226       return __LINE__; 227  } 228  } 229   return 0; 230 } 231 
232 template <typename T>
233 int equation<T>::normalization() 234 { 235   const T coeff = e_vec_.front().coefficient; 236   if(fabs(coeff) < 1e-6){ 237     if(e_vec_.empty()){ 238       //std::cerr << "# [info] this equation is empty." << std::endl;
239       return 0; 240     }else
241       std::cerr << "# [error] this expression should be removed." << std::endl; 242     return __LINE__; 243  } 244 
245   value() /= coeff; 246   for(typename std::list<expression<T> >::iterator lit = e_vec_.begin(); 247       lit != e_vec_.end(); ++lit){ 248     expression<T> & ep = *lit; 249     ep.coefficient /= coeff; 250  } 251   return 0; 252 } 253 
254 template <typename T>
255 int equation<T>::state() const
256 { 257   if(e_vec_.empty()){ 258     if(fabs(value()) < 1e-8) 259       return 0; // is cleared
260     return -1; // is conflicted
261   }else{ 262     if(e_vec_.size() == 1) 263       return 1; // finial variant
264     else
265       return 2; // not ready
266  } 267 } 268 
269 template <typename T>
270 equation<T> & equation<T>::operator -= (const equation<T> & b ) 271 { 272   if(&b == this) { 273  e_vec_.clear(); 274     value() = 0; 275     return *this; 276  } 277 
278   for(typename std::list<expression<T> >::const_iterator lecit_b =
279       b.e_vec_.begin(); lecit_b != b.e_vec_.end(); ++lecit_b){ 280     const expression<T> & exp = *lecit_b; 281     const size_t &node_idx = exp.index; 282     assert(fabs(exp.coefficient) > 1e-6); 283     bool is_found = false; 284 
285     for(typename std::list<expression<T> >::iterator leit_a = e_vec_.begin(); 286         leit_a != e_vec_.end(); ++leit_a){ 287       expression<T> & exp_a = *leit_a; 288       if(exp_a.index == node_idx){ 289         exp_a.coefficient -= exp.coefficient; 290         // zeros
291         if(fabs(exp_a.coefficient) < 1e-6) { 292  e_vec_.erase(leit_a); 293           is_found = true; 294           break; 295  } 296  } 297  } 298     if(!is_found){ 299       e_vec_.push_back(make_expression(node_idx, -1 * exp.coefficient)); 300  } 301  } 302 
303   value() -= b.value(); 304  sort_equation(); 305  normalization(); 306 
307   return *this; 308 } 309 
310 template <typename T>
311 int equation<T>::update(const size_t & node_idx, const  T & node_value) 312 { 313   for(typename std::list<expression<T> >::iterator leit = e_vec_.begin(); 314       leit != e_vec_.end();){ 315     expression<T> & exp = *leit; 316     if(exp.index == node_idx){ 317       value() -= exp.coefficient * node_value; 318       e_vec_.erase(leit++); 319  } 320     ++leit; 321  } 322   return 0; 323 } 324 
325 //! @brief this class only handle Ai+Bi=Ci
326 template <typename T>
327 class gauss_eliminator{ 328   BOOST_MPL_ASSERT_MSG((boost::is_same<T,double>::value ) ||
329                        (boost::is_same<T,float>::value ), 330                        NON_FLOAT_TYPES_ARE_NOT_ALLOWED, (void)); 331 public: 332   /** 333  * @brief construct gauss_eliminator class 334  * 335  * @param nodes input nodes 336  * @param node_flag input node_flag which will be tagged as true if the 337  * corresponding node is known 338  */
339   gauss_eliminator(zjucad::matrix::matrix<T> & nodes, 340                    std::vector<bool> & node_flag) 341  :nodes_(nodes), node_flag_(node_flag){ 342  idx2equation_.resize(nodes_.size()); 343  } 344 
345   /** 346  * @brief add equation to gauss_eliminator, every time an equation is added, 347  * eliminate function is called. 348  * 349  * @param input equation 350  * @return int 351    */
352   int add_equation(const equation<T> & e); 353 
354   /** 355  * @brief This function will start to eliminate equations above all added equations 356  * 357  * @return int return 0 if works fine, or return non-zeros 358    */
359   int eliminate(); 360 
361 
362   /** 363  * @brief update the equation, it will check all variant, if a variant is 364  * already known, update this equation 365  * 366  * @param eq input equation 367  * @return int return 0 if nothing changes, or retunr 1; 368    */
369   int update_equation(equation<T> & eq); 370 
371 private: 372   zjucad::matrix::matrix<T> & nodes_; 373   std::vector<bool> & node_flag_; 374   std::list<equation<T> > es; 375 
376   typedef typename std::list<equation<T> >::iterator equation_ptr; 377   std::vector<std::list<equation_ptr> > idx2equation_; 378 
379   // this map store the smallest expression
380   typedef typename std::map<size_t, std::list<equation_ptr> >::iterator prime_eq_ptr; 381   std::map<size_t, std::list<equation_ptr> > prime_idx2equation_; 382 }; 383 
384 template <typename T>
385 int gauss_eliminator<T>::add_equation(const equation<T> & e){ 386  es.push_back(e); 387   equation<T> & e_back = es.back(); 388   for(typename equation<T>::eq_iterator eit = e_back.begin(); eit != e_back.end(); ){ 389     if(node_flag_[eit->index]){ 390       e_back.value() -= nodes_[eit->index] * eit->coefficient; 391       e_back.e_vec_.erase(eit++); 392     }else
393       ++eit; 394  } 395   if(e_back.state() == 0){// this equation is cleared
396  es.pop_back(); 397     return 0; 398   }else if(e_back.state() == -1){ 399     std::cerr << "# [error] strange conflict equation: " << std::endl; 400     std::cerr << e; 401  es.pop_back(); 402     return __LINE__; 403  } 404  e_back.standardization(); 405 
406   for(typename equation<T>::eq_const_iterator it = e.begin(); 407       it != e.end(); ++it){ 408     equation_ptr end_ptr = es.end(); 409     idx2equation_[it->index].push_back(--end_ptr); 410  } 411 
412   equation_ptr end_ptr = es.end(); 413   prime_idx2equation_[e_back.get_prim_idx()].push_back(--end_ptr); 414  eliminate(); 415   return 0; 416 } 417 
418 template <typename T>
419 int gauss_eliminator<T>::update_equation(equation<T> & eq) 420 { 421   for(typename std::list<expression<T> >::iterator it = eq.e_vec_.begin(); 422       it != eq.e_vec_.end(); ){ 423     const expression<T> & exp = *it; 424     if(node_flag_[exp.index]){ 425       eq.value() -= exp.coefficient * nodes_[exp.index]; 426       eq.e_vec_.erase(it++); 427     }else
428       ++it; 429  } 430  eq.standardization(); 431   return 0; 432 } 433 
434 template <typename T>
435 int gauss_eliminator<T>::eliminate() 436 { 437   std::cerr << std::endl; 438   while(1){ 439     bool is_modified = false; 440 
441     for(prime_eq_ptr ptr = prime_idx2equation_.begin(); 442         ptr != prime_idx2equation_.end();) 443  { 444       std::list<equation_ptr> & dle = ptr->second; 445       if(dle.empty()) { 446         prime_idx2equation_.erase(ptr++); 447         is_modified = true; 448         continue; 449       }else if(dle.size() == 1){ // contain only one equation
450         const int state_ = dle.front()->state(); 451         if(state_ == 0) { // cleared
452  es.erase(dle.front()); 453           continue; 454         }else if(state_ == -1){ // conflict equation
455           std::cerr << "# [error] conflict equation " << std::endl; 456           return __LINE__; 457         }else if(state_ == 1){ // finial variant
458           const equation<T> & eq = *dle.front(); 459           const T &value_ = eq.value(); 460           // prime index's coefficient should be 1
461           assert(fabs(eq.e_vec_.front().coefficient - 1) < 1e-6); 462           const size_t index = eq.get_prim_idx(); 463           if(node_flag_[index]){ 464             if(fabs(nodes_[index] - value_) > 1e-6){ 465               std::cerr << "# [error] conficts happen, node " << index 466                         << " has different value " << nodes_[index] << ","
467                         << value_ << std::endl; 468               return __LINE__; 469  } 470           }else{ 471             // update corresponding equations with the new node value
472             nodes_[index] = value_; 473             node_flag_[index] = true; 474             std::list<equation_ptr> & node_linked_eq = idx2equation_[index]; 475             for(typename std::list<equation_ptr>::iterator leqit =
476                 node_linked_eq.begin(); leqit != node_linked_eq.end(); ++leqit){ 477               equation<T> & eq = *(*leqit); 478  eq.update(index, nodes_[index]); 479  eq.standardization(); 480  } 481  node_linked_eq.clear(); 482  prime_idx2equation_.erase(ptr); 483             is_modified = true; 484  } 485  } 486         ++ptr; 487       }else{ 488         assert(dle.size() > 1); 489         // this prime_index point to several equations, 490         // which sould be eliminated
491         typename std::list<equation_ptr>::iterator begin = dle.begin(); 492         typename std::list<equation_ptr>::iterator first = begin++; 493         // to keep each prime index linked only one equation
494         for(typename std::list<equation_ptr>::iterator next = begin; 495             next != dle.end();){ 496           // to eliminate the prime index, each equation minus the first one
497           *(*next) -= *(*first); 498           (*next)->standardization(); 499           const size_t prim_index = (*next)->get_prim_idx(); 500           assert(prim_index >= (*first)->get_prim_idx()); 501           prime_idx2equation_[prim_index].push_back(*next); 502           dle.erase(next++); 503  } 504         ++ptr; 505       } // end else
506  } 507     if(!is_modified) break; 508  } 509   return 0; 510 } 511 } 512 } 513 #endif // GAUSS_ELIMINATION_H

 

下面是一个简单的测试例子:


  //2x2+2x1+2x1=2
//x1=1
//5x1+x2=2
  
 1  int main()  2 { typedef double val_type;  3   matrix<val_type> node = zeros<val_type>(4,1);  4   vector<bool> node_flag(4,false);  5 
 6   jtf::algorithm::gauss_eliminator<val_type> ge(node, node_flag);  7 
 8  {  9     equation<val_type> eq; 10     eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(1), 11                                                       static_cast<val_type>(2))); 12     eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0), 13                                                       static_cast<val_type>(2))); 14     eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0), 15                                                       static_cast<val_type>(2))); 16     eq.value() = 2; 17  ge.add_equation(eq); 18  } 19  { 20     equation<val_type> eq; 21     eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0), 22                                                       static_cast<val_type>(1))); 23     eq.value() = 1; 24  ge.add_equation(eq); 25  } 26  { 27     equation<val_type> eq; 28     eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(0), 29                                                       static_cast<val_type>(5))); 30 
31     eq.add_expression(jtf::algorithm::make_expression(static_cast<size_t>(1), 32                                                       static_cast<val_type>(1))); 33 
34     eq.value() = 2; 35  ge.add_equation(eq); 36  } 37   for(size_t t = 0; t < node_flag.size(); ++t){ 38     if(node_flag[t] == true){ 39       cerr << "# node " << t << " = " << node[t] << endl; 40     }else
41       cerr << "# node " << t << " unknown." << endl; 42  } 43 return 0; 44 }

这段代码在 g++ 4.6.3 上编译通过。如果有更方便、更简单的方法,欢迎大家指教~

你可能感兴趣的:(实现)