Python
实现import random
def partition(nums, low, high):
pivot_index = random.randint(low, high)
pivot = nums[pivot_index]
# 将pivot元素移动到列表的最右边
nums[pivot_index], nums[high] = nums[high], nums[pivot_index]
# 通过交换操作, 将小于 pivot 的元素移动到左边, 大于 pivot 的元素移动到右边
i = low
for j in range(low, high):
if nums[j] < pivot:
nums[i], nums[j] = nums[j], nums[i]
i += 1
# 将 pivot 元素放置到正确的位置
nums[i], nums[high] = nums[high], nums[i]
return i
def quick_select(nums, low, high, k):
if low == high:
return nums[low]
# 划分数组, 并获取 pivot 元素的索引
pivot_index = partition(nums, low, high)
j = pivot_index - low + 1
# 如果 pivot 元素的索引等于 k, 则返回该元素
if j == k:
return nums[pivot_index]
# 如果 pivot 元素的索引大于 k, 则在左侧继续查找
elif j > k:
return quick_select(nums, low, pivot_index - 1, k)
# 如果 pivot 元素的索引小于 k, 则在右侧继续查找
else:
return quick_select(nums, pivot_index + 1, high, k - j)
def find_kth_smallest(nums, k):
if k < 1 or k > len(nums):
raise ValueError('Invalid value of k')
return quick_select(nums, 0, len(nums) - 1, k)
nums = [3, 1, 5, 2, 4]
k = 2
res = find_kth_smallest(nums, k)
print(f'第 {k} 小的元素为 {res}')
如果能在线性时间内找到一个划分基准,使得按这个基准划分出的两个子数组的长度都至少为原数组长度的 ε \varepsilon ε倍( 0 < ε < 1 0< \varepsilon < 1 0<ε<1是某个常数),那么在最坏情况下用 O ( n ) O(n) O(n)时间就可以完成选择任务
将 n n n个输入元素划分成 ⌈ n / 5 ⌉ \left\lceil n / 5 \right\rceil ⌈n/5⌉个组,每组 5 5 5个元素(除可能有一个组不是 5 5 5个元素外),用任意一种排序算法,将每组中的元素排好序,并取出每组的中位数,共 ⌈ n / 5 ⌉ \left\lceil n / 5 \right\rceil ⌈n/5⌉个
递归调用找出这 ⌈ n / 5 ⌉ \left\lceil n / 5 \right\rceil ⌈n/5⌉个元素的中位数,如果 ⌈ n / 5 ⌉ \left\lceil n / 5 \right\rceil ⌈n/5⌉是偶数,就找它的两个中位数中较大的一个,然后以这个元素作为划分基准
设所有元素互不相同,找出的基准 x x x至少比 3 ⌊ ( n − 5 ) / 10 ⌋ 3 \left\lfloor (n - 5) / 10 \right\rfloor 3⌊(n−5)/10⌋个元素大,至少比 3 ⌊ ( n − 5 ) / 10 ⌋ 3 \left\lfloor (n - 5) / 10 \right\rfloor 3⌊(n−5)/10⌋个元素小,当 n ≥ 75 n \geq 75 n≥75时, 3 ⌊ ( n − 5 ) / 10 ⌋ ≥ n / 4 3 \left\lfloor (n - 5) / 10 \right\rfloor \geq n / 4 3⌊(n−5)/10⌋≥n/4,所以按此基准划分所得的两个子数组的长度都至少缩短 1 / 4 1 / 4 1/4
T ( n ) ≤ { C 1 , n < 75 C 2 n + T ( n / 5 ) + T ( 3 n / 4 ) , n ≥ 75 T(n) \leq \begin{cases} C_{1} , & n < 75 \\ C_{2} n + T(n / 5) + T(3n / 4) , & n \geq 75 \end{cases} T(n)≤{C1,C2n+T(n/5)+T(3n/4),n<75n≥75
T ( n ) = O ( n ) T(n) = O(n) T(n)=O(n)
Python
实现import statistics
def find_median_of_medians(arr):
# 将数组划分为大小为 5 的子数组
subarrays = [arr[i:i + 5] for i in range(0, len(arr), 5)]
# 计算每个子数组的中位数
medians = [statistics.median(subarray) for subarray in subarrays]
# 如果元素数量小于等于 5, 直接返回中位数
if len(medians) <= 5:
return statistics.median(medians)
# 递归调用中位数的中位数算法
return find_median_of_medians(medians)
def linear_time_select(arr, k):
# 找到中位数的中位数
median_of_medians = find_median_of_medians(arr)
# 将数组划分为三个部分
less = [x for x in arr if x < median_of_medians]
equal = [x for x in arr if x == median_of_medians]
greater = [x for x in arr if x > median_of_medians]
# 根据划分后的数组长度选择下一步操作
if k <= len(less):
# 在较小的部分递归查找第 k 小元素
return linear_time_select(less, k)
elif k <= len(less) + len(equal):
# 第 k 小元素等于中位数的中位数
return median_of_medians
else:
# 在较大的部分递归查找第 k 小元素
return linear_time_select(greater, k - len(less) - len(equal))
arr = [3, 1, 5, 2, 4, 9, 7, 8, 6]
k = 5
res = linear_time_select(arr, k)
print(f'第 {k} 小的元素为 {res}')