STL源码解析 - nth_element


STL源码解析 - nth_element

nth_element 模板函数具有两个版本

 

[cpp]  view plain copy
  1. template<class _RanIt>  
  2. void nth_element(_RanIt _First, _RanIt _Nth, _RanIt _Last);  


[cpp]  view plain copy
  1. template<class _RanIt, class _Pr>  
  2. void nth_element(_RanIt _First, _RanIt _Nth, _RanIt _Last, _Pr _Pred);  


其功能是对区间 [_First, _Last) 的元素进行重排,其中位于位置 _Nth 的元素与整个区间排序后位于位置 _Nth 的元素相同,并且满足在位置 _Nth 之前的所有元素都“不大于”它和位置 _Nth 之后的所有元素都“不小于”它,而且并不保证 _Nth 的前后两个区间的所有元素保持有序。

第一个版本,比较操作默认使用小于操作符(operator<);第二个版本,使用自定义谓词 "_Pred" 定义“小于”操作(Less Than)。

算法的空间复杂度为O(1)。

由于算法主要分两部分实现,第一部分是进行二分法弱分区,第二部分是对包含 _Nth 的位置的区间进行插入排序(STL的阈值为32)。当元素较多时平均时间复杂度为O(N),元素较少时最坏情况下时间复杂度为O(N^2)。

下面针对第一个版本的算法源代码进行注释说明,版本为 Microsoft Visual Studio 2008 SP1 安装包中的 algorithm 文件

 

[cpp]  view plain copy
  1. template<class _RanIt> inline  
  2. void nth_element(_RanIt _First, _RanIt _Nth, _RanIt _Last)  
  3. {   // order Nth element, using operator<  
  4.     _Nth_element(_CHECKED_BASE(_First), _CHECKED_BASE(_Nth), _CHECKED_BASE(_Last)); // 转调用内部实现函数  
  5. }  


_Nth_element 函数实现,其中 _ISORT_MAX 值为 32。

 

[cpp]  view plain copy
  1. template<class _RanIt> inline  
  2.     void _Nth_element(_RanIt _First, _RanIt _Nth, _RanIt _Last)  
  3.     {   // order Nth element, using operator<  
  4.     _DEBUG_RANGE(_First, _Last);  
  5.     for (; _ISORT_MAX < _Last - _First; )  
  6.         {   // divide and conquer, ordering partition containing Nth  
  7.         pair<_RanIt, _RanIt> _Mid =  
  8.             std::_Unguarded_partition(_First, _Last);  
  9.   
  10.         if (_Mid.second <= _Nth)  
  11.             _First = _Mid.second;  
  12.         else if (_Mid.first <= _Nth)  
  13.             return// Nth inside fat pivot, done  
  14.         else  
  15.             _Last = _Mid.first;  
  16.         }  
  17.   
  18.     // 插入排序  
  19.     std::_Insertion_sort(_First, _Last);    // sort any remainder  
  20.     }  

 

_Unguarded_partition 函数实现

 

[cpp]  view plain copy
  1. template<class _RanIt> inline  
  2.     pair<_RanIt, _RanIt> _Unguarded_partition(_RanIt _First, _RanIt _Last)  
  3.     {   // partition [_First, _Last), using operator<  
  4.     _RanIt _Mid = _First + (_Last - _First) / 2;    // sort median to _Mid  
  5.     std::_Median(_First, _Mid, _Last - 1);  // 端点排序  
  6.     _RanIt _Pfirst = _Mid;  
  7.     _RanIt _Plast = _Pfirst + 1;    // 起始返回区间为 [_Mid, _Mid + 1)  
  8.   
  9.     // 以下两个循环将不处理与 *_Mid 值相同的元素  
  10.     while (_First < _Pfirst  
  11.         && !_DEBUG_LT(*(_Pfirst - 1), *_Pfirst)  
  12.         && !(*_Pfirst < *(_Pfirst - 1)))  
  13.         --_Pfirst;  
  14.     while (_Plast < _Last  
  15.         && !_DEBUG_LT(*_Plast, *_Pfirst)  
  16.         && !(*_Pfirst < *_Plast))  
  17.         ++_Plast;  
  18.   
  19.     // 当前返回区间为 [_Pfirst, _Plast),且区间内值均相等  
  20.     _RanIt _Gfirst = _Plast;  
  21.     _RanIt _Glast = _Pfirst;  
  22.   
  23.     for (; ; )  
  24.         {   // partition  
  25.         // 后半区间  
  26.         for (; _Gfirst < _Last; ++_Gfirst)  
  27.             if (_DEBUG_LT(*_Pfirst, *_Gfirst))  // 大于首值,迭代器后移  
  28.                 ;  
  29.             else if (*_Gfirst < *_Pfirst)        // 小于首值,退出循环  
  30.                 break;  
  31.             else  
  32.                 std::iter_swap(_Plast++, _Gfirst);  // 与首值相等,末迭代器后移,更新末值  
  33.         // 前半区间  
  34.         for (; _First < _Glast; --_Glast)  
  35.             if (_DEBUG_LT(*(_Glast - 1), *_Pfirst)) // 小于首值,迭代器前移  
  36.                 ;  
  37.             else if (*_Pfirst < *(_Glast - 1))       // 大于首值,退出循环  
  38.                 break;  
  39.             else  
  40.                 std::iter_swap(--_Pfirst, _Glast - 1);  // 与首值相等,首迭代器前移,更新首值  
  41.   
  42.         // 整体区间已经处理结束  
  43.         if (_Glast == _First && _Gfirst == _Last)  
  44.             return (pair<_RanIt, _RanIt>(_Pfirst, _Plast));  
  45.   
  46.         // 到达起点  
  47.         if (_Glast == _First)  
  48.             {   // no room at bottom, rotate pivot upward  
  49.             if (_Plast != _Gfirst)  
  50.                 std::iter_swap(_Pfirst, _Plast);    // if 成立,_Pfirst 暂存大值  
  51.             ++_Plast;                               // 末迭代器后移  
  52.             std::iter_swap(_Pfirst++, _Gfirst++);   // if 成立时,小值将存于返回区间首,最终结果是,返回区间整体右移  
  53.             }  
  54.         else if (_Gfirst == _Last)  // 到达终点  
  55.             {   // no room at top, rotate pivot downward  
  56.             if (--_Glast != --_Pfirst)  
  57.                 std::iter_swap(_Glast, _Pfirst);    // if 成立,_Pfirst 暂存大值  
  58.             std::iter_swap(_Pfirst, --_Plast);  // if 成立时,大值将存于返回区间尾,最终结果是,返回区间整体左移  
  59.             }  
  60.         else  
  61.             std::iter_swap(_Gfirst++, --_Glast);    // 交换后,*_Glast < *_Pfirst < *(_Gfirst - 1)  
  62.         }  
  63.     }  

 

_Median 和 _Med3 两个函数,其作用是对区间内的特定几个数进行排序

 

[cpp]  view plain copy
  1. template<class _RanIt> inline  
  2.     void _Med3(_RanIt _First, _RanIt _Mid, _RanIt _Last)  
  3.     {   // sort median of three elements to middle - 3 点排序  
  4.     if (_DEBUG_LT(*_Mid, *_First))  
  5.         std::iter_swap(_Mid, _First);  
  6.     if (_DEBUG_LT(*_Last, *_Mid))  
  7.         std::iter_swap(_Last, _Mid);  
  8.     if (_DEBUG_LT(*_Mid, *_First))  
  9.         std::iter_swap(_Mid, _First);  
  10.     }  
  11.   
  12. template<class _RanIt> inline  
  13.     void _Median(_RanIt _First, _RanIt _Mid, _RanIt _Last)  
  14.     {   // sort median element to middle  
  15.     if (40 < _Last - _First)  
  16.         {   // median of nine - 9 端点排序  
  17.         size_t _Step = (_Last - _First + 1) / 8;  
  18.         std::_Med3(_First, _First + _Step, _First + 2 * _Step);  
  19.         std::_Med3(_Mid - _Step, _Mid, _Mid + _Step);  
  20.         std::_Med3(_Last - 2 * _Step, _Last - _Step, _Last);  
  21.         std::_Med3(_First + _Step, _Mid, _Last - _Step);  
  22.         }  
  23.     else  
  24.         std::_Med3(_First, _Mid, _Last);  
  25.     }  

 

对于第二个版本,算法思想相同,只是要做比较操作时,将用 _Pred 替换 operator< 操作符,同时也看到算法的核心主要在于 _Unguarded_partition 这个函数。

_Insertion_sort 函数,插入排序

 

[cpp]  view plain copy
  1. template<class _BidIt> inline  
  2.     void _Insertion_sort(_BidIt _First, _BidIt _Last)  
  3.     {   // insertion sort [_First, _Last), using operator<  
  4.     std::_Insertion_sort1(_First, _Last, _Val_type(_First)); // 转调用 _Insertion_sort1  
  5.     }  


_Insertion_sort1 函数

 

[cpp]  view plain copy
  1. template<class _BidIt,  
  2.     class _Ty> inline  
  3.     void _Insertion_sort1(_BidIt _First, _BidIt _Last, _Ty *)  
  4.     {   // insertion sort [_First, _Last), using operator<  
  5.     if (_First != _Last)  
  6.         for (_BidIt _Next = _First; ++_Next != _Last; )  
  7.             {   // order next element  
  8.             _BidIt _Next1 = _Next;  
  9.             _Ty _Val = *_Next;  
  10.   
  11.             // 小于首值时,整体后移,有可能使用 memmove,因而存在优化  
  12.             if (_DEBUG_LT(_Val, *_First))  
  13.                 {   // found new earliest element, move to front - [_First, _Next) => [..., ++Next1)  
  14.                 _STDEXT unchecked_copy_backward(_First, _Next, ++_Next1);  
  15.                 *_First = _Val;  
  16.                 }  
  17.             else  
  18.                 {   // look for insertion point after first  
  19.                 for (_BidIt _First1 = _Next1;  
  20.                     _DEBUG_LT(_Val, *--_First1);  
  21.                     _Next1 = _First1)  
  22.                     *_Next1 = *_First1; // move hole down - 逐项后移  
  23.                 *_Next1 = _Val; // insert element in hole  
  24.                 }  
  25.             }  
  26.     }  

 

至此,我们已经完全理解 nth_element 的算法思想了,并且明白为何它的时间复杂度和空间复杂度都很低,当不需要对某个数组进行全部排序而想找出满足某一条件(_Pred)的第 N 个值时,便可采用此算法,同时需要注意的是,此算法只对“随机访问迭代器”有效(如 vector),如果需要对 list 使用此算法,可先将 list 的所有元素拷贝至 vector(或者存储 list::iterator,对自定义类型效率更高),再使用此算法。



本文为senlie原创,转载请保留此地址:http://blog.csdn.net/zhengsenlie


nth_element
------------------------------------------------------------------------------


描述:重新排序,使得[nth,last)内没有任何一个元素小于[first,nth)内的元素,
但对于[first,nth)和[nth,last)两个子区间内的元素次序则无任何保证。
思路:
1.以 median-of-3-partition 将整个序列分割为更小的左、右子序列
2.如果 nth 迭代器落于左序列,就再对左子序列进行分割,否则就再对右子序列进行分割
3.直到分割后的子序列长大于3,对最后这个待分割的子序列做 Insertion Sort
STL源码解析 - nth_element_第1张图片
复杂度:O(n)
源码:

[cpp]  view plain copy
  1. template <class RandomAccessIterator>  
  2. inline void nth_element(RandomAccessIterator first, RandomAccessIterator nth,  
  3.                         RandomAccessIterator last) {  
  4.   __nth_element(first, nth, last, value_type(first));  
  5. }  
  6.   
  7.   
  8. template <class RandomAccessIterator, class T>  
  9. void __nth_element(RandomAccessIterator first, RandomAccessIterator nth,  
  10.                    RandomAccessIterator last, T*) {  
  11.   while (last - first > 3) {  
  12.     //采用 median-of-3-partition 。参数:(first,last,pivot)  
  13.     //返回一个迭代器,指向分割后的右段第一个元素  
  14.     RandomAccessIterator cut = __unguarded_partition  
  15.       (first, last, T(__median(*first, *(first + (last - first)/2),  
  16.                                *(last - 1))));  
  17.     if (cut <= nth) //如果  nth 落于右段,再对右段实施分割  
  18.       first = cut;  
  19.     else  //如果 nth 落于左段,对左段实施分割  
  20.       last = cut;  
  21.   }  
  22.   __insertion_sort(first, last); //对分割后的子序列做 Insertion Sort  
  23. }  
  24.   
  25.   
  26. template <class RandomAccessIterator, class T>  
  27. RandomAccessIterator __unguarded_partition(RandomAccessIterator first,   
  28.                                            RandomAccessIterator last,   
  29.                                            T pivot) {  
  30.   while (true) {  
  31.     while (*first < pivot) ++first;  
  32.     --last;  
  33.     while (pivot < *last) --last;  
  34.     if (!(first < last)) return first;  
  35.     iter_swap(first, last);  
  36.     ++first;  
  37.   }  
  38. }      

示例:
[cpp]  view plain copy
  1. int A[] = {7, 2, 6, 11, 9, 3, 12, 10, 8, 4, 1, 5};  
  2. const int N = sizeof(A) / sizeof(int);  
  3.   
  4.   
  5. nth_element(A, A + 6, A + N);  
  6. copy(A, A + N, ostream_iterator<int>(cout, " "));  
  7. // The printed result is "5 2 6 1 4 3 7 8 9 10 11 12".  


你可能感兴趣的:(算法和数据结构)