Select 问题: 在一个无序的数组中 找到第 n 大的元素。
思路 1: 排序,O(NlgN)
思路 2: 利用快排的 RandomizedPartition(), 平均复杂度是 O(N)
思路 3: 同样是利用快排的 Partition(), 但是选择 pivot 的时候不是采用随机,而是通过一种特殊的方法。从而使复杂度最坏情况下是 O(N)。
本文介绍 STL 算法库中 nth_elemnt 的实现代码。
STL 采用的算法是: 当数组长度 <= 3时, 采用插入排序。
当长度 > 3时, 采用快排 Partition 的思想;
一、使用说明
void
nth_element (RandomAccessIteratorbeg,
RandomAccessIterator nth,
RandomAccessIterator end)
void
nth_element (RandomAccessIterator beg,
RandomAccessIterator nth,
RandomAccessIterator end,
BinaryPredicate op)
1. 两个函数都是让 第 n 个位置上的元素就位,
所有在位置 n 之前的元素都小于或等于它,
所有在位置 n 之后的元素都大于或等于它。
2. 复杂度: 平均复杂度是 O(N)
以下例子是使用范例:
// copyright @ L.J.SHOU Feb.23, 2014 #include <iostream> #include <algorithm> #include <iterator> using namespace std; int main(void) { int a[]={3,5,2,6,1,4}; nth_element(a, a+3, a+sizeof(a)/sizeof(int)); cout << "The fourth element is: " << a[3] << endl; // output array a[] copy(a, a+sizeof(a)/sizeof(int), ostream_iterator<int>(cout, " ")); return 0; }
程序输出结果:
The fourth element is: 4二、源码分析
// nth_element() and its auxiliary functions. template <class _RandomAccessIter, class _Tp> void __nth_element(_RandomAccessIter __first, _RandomAccessIter __nth, _RandomAccessIter __last, _Tp*) { while (__last - __first > 3) { _RandomAccessIter __cut = __unguarded_partition(__first, __last, _Tp(__median(*__first, *(__first + (__last - __first)/2), *(__last - 1)))); if (__cut <= __nth) __first = __cut; else __last = __cut; } __insertion_sort(__first, __last); } template <class _RandomAccessIter> inline void nth_element(_RandomAccessIter __first, _RandomAccessIter __nth, _RandomAccessIter __last) { __STL_REQUIRES(_RandomAccessIter, _Mutable_RandomAccessIterator); __STL_REQUIRES(typename iterator_traits<_RandomAccessIter>::value_type, _LessThanComparable); __nth_element(__first, __nth, __last, __VALUE_TYPE(__first)); }
template <class _RandomAccessIter, class _Tp> _RandomAccessIter __unguarded_partition(_RandomAccessIter __first, _RandomAccessIter __last, _Tp __pivot) { while (true) { while (*__first < __pivot) ++__first; --__last; while (__pivot < *__last) --__last; if (!(__first < __last)) return __first; iter_swap(__first, __last); ++__first; } }
_unguarded_partition 就是快排的 partition, 将数组分成两部分,左边的元素都小于或者等于 pivot, 右边的元素都大于或者等于 pivot.
从上述代码可以看出, nth_element 采用的 pivot 是 首元素,尾元素,中间元素,三个数的median.
通过_unguarded_partition 将数组分成两部分,
如果 nth 这个迭代器在左半边,则继续在左半边搜索;
若 nth 在右半边, 则在右半边搜索;
直到数组的长度 <= 3,时, 采用插入排序。这时 nth 迭代器所指向的数就归位了,而且它的左边元素都小于或者等于它, 右边元素都大于或者等于它。