通过/超时:在牛客网试题【排序】中的提交结果。
算法复杂度分析
算法的时间复杂度和空间复杂度详解
最详细的解说—时间和空间复杂度
时间复杂度反映了算法运行时间随输入规模增长而增长的量级,记为T(N) = O( g(N) )
。
T(N)
:算法运行的时间频度,即基本语句(例如:比较)的执行次数,可以用于衡量算法的精确运行时间,一般表现为多项式(例如:a * N^2 + b * N + c
)。当输入规模N
足够大时,精确运行时间中的倍增常量和低阶项被输入规模本身的影响所支配,所以我们只需关注运行时间的增长量级(例如:N^2
),即研究算法的渐近效率,使用以下3种渐近记号来表示,详见下图(图片来自《算法导论》第 3 章 函数的增长)。
Θ( g(N) )
:T(N)
函数的渐近紧确界,即同时满足上界O( g(N) )
和下界Ω( g(N) )
。
O( g(N) )
:T(N)
函数的渐近上界,例如O( N^2 )
可以理解为当N
足够大时,算法的运行时间频度总是小于等于N^2
的常数倍,常用于表示最差/平均时间复杂度。
Ω( g(N) )
:T(N)
函数的渐近下界,例如Ω( N^2 )
可以理解为当N
足够大时,算法的运行时间频度总是大于等于N^2
的常数倍,常用于表示最优时间复杂度。
《算法导论》在第二部分(排序和顺序统计量)的前言中,对常见排序算法的时间复杂度总结如下:
根据《算法导论》第二部分(排序和顺序统计量)的前言,如果输入数组中仅有常数个元素需要在排序过程中存储在数组之外,则称排序算法是原址的(in place)。
根据《算法导论》8.2 章节,排序算法是的稳定的:具有相同值的元素在输出数组中的相对次序与它们在输入数组中的相对次序相同。当所排序的数据还附带其他对象信息时,算法的稳定性就比较重要。
根据《算法导论》8.1 章节,在最坏情况下:
Ω(N * lgN)
次比较(下界)O(N * lgN)
O(N * lgN)
。根据《算法导论》7.1 章节,虽然快速排序的最差时间复杂度很差(O(N^2)
),但它的平均性能非常好:
O(N * lgN)
中隐含的常数因子非常小,低于归并排序和堆排序平均时间复杂度:O(N^2)
最差时间复杂度:O(N^2)
最优时间复杂度:O(N)
需要加上两个优化项
原址性:原址,只有交换元素时需要在原数组之外临时存储一个元素。
稳定性:稳定,在相邻元素相等时,它们并不会交换位置。
class Solution:
def MySort(self, arr):
n = len(arr)
for i in range(n-1, -1, -1):
for j in range(1, i+1):
if arr[j-1] > arr[j]:
arr[j-1], arr[j] = arr[j], arr[j-1]
return arr
class Solution:
def MySort(self, arr):
n = len(arr)
for i in range(n):
for j in range(1, n-i):
if arr[j-1] > arr[j]:
arr[j-1], arr[j] = arr[j], arr[j-1]
return arr
从上述两个版本可以看出,外循环的正/倒序无所谓,而内循环的正/倒序取决于冒最大值,还是冒最小值。
class Solution:
def MySort(self, arr):
n = len(arr)
for i in range(n):
for j in range(n-1, i, -1):
if arr[j-1] > arr[j]:
arr[j-1], arr[j] = arr[j], arr[j-1]
return arr
【优化一】缩小外循环:如果某一轮没有任何可交换的元素,说明已排序完成,不用再进行下一轮冒泡了,减少了总冒泡轮数 i,需要用 swap_flag 标记本轮是否发生过交换。
class Solution:
def MySort(self, arr):
"""【优化一】冒最大值,缩小外循环"""
n = len(arr)
for i in range(n):
swap_flag = False
for j in range(1, n-i):
if arr[j-1] > arr[j]:
arr[j-1], arr[j] = arr[j], arr[j-1]
swap_flag = True
if not swap_flag:
break
return arr
【优化二】缩小内循环:记录本轮最后发生交换的位置(last_swap_pos),此位置之后已排好序,下轮到此位置即可,减少了下轮要检查的元素数量。
class Solution:
def MySort(self, arr):
"""【优化二】冒最大值,缩小内循环"""
n = len(arr)
last_swap_pos = n
for i in range(n):
for j in range(1, last_swap_pos):
if arr[j-1] > arr[j]:
arr[j-1], arr[j] = arr[j], arr[j-1]
last_swap_pos = j
return arr
两个优化项同时使用
class Solution:
def MySort(self, arr):
n = len(arr)
last_swap_pos = n
for i in range(n):
swap_flag = False
for j in range(1, last_swap_pos):
if arr[j-1] > arr[j]:
arr[j-1], arr[j] = arr[j], arr[j-1]
swap_flag = True
last_swap_pos = j
if not swap_flag:
break
return arr
只需把比较符号换成 >
class Solution:
def MySort(self, arr):
"""降序"""
n = len(arr)
last_swap_pos = n
for i in range(n-1, -1, -1):
swap_flag = False
for j in range(1, last_swap_pos):
if arr[j-1] < arr[j]:
arr[j-1], arr[j] = arr[j], arr[j-1]
swap_flag = True
last_swap_pos = j
if not swap_flag:
break
return arr
冒泡排序的变形版,也称搅拌排序。
class Solution:
def MySort(self, arr):
n = len(arr)
for i in range(n-1, -1, -1):
for j in range(1, i+1):
if arr[j] < arr[j-1]:
arr[j], arr[j-1] = arr[j-1], arr[j]
for j in range(i-1, 0, -1):
if arr[j] < arr[j-1]:
arr[j], arr[j-1] = arr[j-1], arr[j]
return arr
平均时间复杂度:O(N * lgN)
,只要分区比例是常数都会产生深度为O(lgN)
的递归树,而每一层的时间代价都是O(N)
,详见《算法导论》章节 7.2 快速排序的性能。
最差时间复杂度:O(n^2)
,分区极不平衡时T(N) = 分区操作 O(N) + 左子区递归 T(N-1) + 右子区递归 T(0)
最优时间复杂度:O(N * lgN)
,分区最平衡时T(N) = 分区操作O(N) + 左子区递归 T(N/2) + 右子区递归 T(N/2)
原址性:原址,只有交换元素时需要在原数组之外临时存储一个元素。
稳定性:快速排序本身是不稳定的,如果修改代码使其稳定,则一定会牺牲性能(即使能保持时间复杂度仍为O(N * lgN)
,但其中隐含的常数因子会增大)。
递归函数
分排函数(左右指针交换版)
不稳定原因
class Solution:
def MySort(self, arr):
"""递归版"""
self.quick_sort_recur(arr, 0, len(arr) - 1)
return arr
def quick_sort_recur(self, arr, start, end):
if start < end:
mid = self.partition(arr, start, end)
self.quick_sort_recur(arr, start, mid - 1)
self.quick_sort_recur(arr, mid + 1, end)
def partition(self, arr, start, end):
"""分排函数【左右指针交换版 + 以start为基准】"""
left = start
right = end
while left < right:
while left < right and arr[right] >= arr[start]:
right -= 1
while left < right and arr[left] <= arr[start]:
left += 1
arr[left], arr[right] = arr[right], arr[left]
arr[start], arr[left] = arr[left], arr[start]
return left
注意
import random
pivot = random.randint(start, end)
arr[pivot], arr[start] = arr[start], arr[pivot]
class Solution:
def MySort(self, arr):
"""递归版"""
self.quick_sort_recur(arr, 0, len(arr) - 1)
return arr
def quick_sort_recur(self, arr, start, end):
if start < end:
# 把数组分成 3 部分:小于基准 | 等于基准 | 大于基准,只需对小值区、大值区进行递归,重复元素多时改善效率
lt, gt = self.partition(arr, start, end)
self.quick_sort_recur(arr, start, lt)
self.quick_sort_recur(arr, gt, end)
def partition(self, arr, start, end):
"""分排函数【左右指针三路版】"""
import random
idx = random.randint(start, end)
arr[idx], arr[end] = arr[end], arr[idx]
pivot = arr[end]
lt = start - 1
gt = end
i = start
while i < gt:
# 遇到等值,扩大等值区
if arr[i] == pivot:
i += 1
# 遇到小值,把小值和等值左边界进行交换,扩大小值区,继续检查下个元素
elif arr[i] < pivot:
lt += 1
arr[lt], arr[i] = arr[i], arr[lt]
i += 1
# 遇到大值,把大值和未知区右边界进行交换,扩大大值区,继续检查刚交换过来的未知元素
else:
gt -= 1
arr[gt], arr[i] = arr[i], arr[gt]
# 把基准值和大值区左边界进行交换
arr[end], arr[gt] = arr[gt], arr[end]
# 至此:[start, lt]是小值,[lt+1, gt]是等值,[gt+1, end]是大值
return lt, gt + 1
用栈来保存递归函数的变参:左右两个子数组的边界标识。
分排函数( 快慢指针交换版)
不稳定原因
class Solution:
def MySort(self, arr):
"""迭代版"""
stack = [(0, len(arr)-1)]
while stack:
start, end = stack.pop()
mid = self.partition(arr, start, end)
if start < mid - 1:
stack.append((start, mid-1))
if mid + 1 < end:
stack.append((mid+1, end))
return arr
def partition(self, arr, start, end):
"""分排函数【快慢指针交换版 + 以end为基准】"""
pivot = arr[end]
slow = start - 1
for fast in range(start, end):
if arr[fast] <= pivot:
arr[fast], arr[slow+1] = arr[slow+1], arr[fast]
slow += 1
arr[end], arr[slow+1] = arr[slow+1], arr[end]
return slow + 1
注意
def partition(self, arr, start, end):
"""分排函数【快慢指针交换版 + 以start为基准】"""
pivot = arr[start]
slow = start
for fast in range(start+1, end+1):
if arr[fast] <= pivot:
arr[fast], arr[slow+1] = arr[slow+1], arr[fast]
slow += 1
# 注意:由于start是小值应该在的位置,则需与slow进行交换。
arr[start], arr[slow] = arr[slow], arr[start]
return slow
class Solution:
def MySort(self, arr):
"""递归版"""
self.quick_sort_recur(arr, 0, len(arr) - 1)
return arr
def quick_sort_recur(self, arr, start, end):
if start < end:
lt, gt = self.partition(arr, start, end)
self.quick_sort_recur(arr, start, lt)
self.quick_sort_recur(arr, gt, end)
def partition(self, arr, start, end):
"""分排函数【快慢指针三路版】"""
import random
pivot = random.randint(start, end)
arr[pivot], arr[end] = arr[end], arr[pivot]
fast = slow = start - 1
for i in range(start, end + 1):
# 遇到大值,无需任何交换,i 继续向右走,扩大大值区右边界
# 遇到等值,把等值与大值区左边界进行交换,扩大等值区右边界
if arr[i] == arr[end]:
fast += 1
arr[fast], arr[i] = arr[i], arr[fast]
# 遇到小值,需交换两次,才能把它交换到小值区,同时等值区整体右移
elif arr[i] < arr[end]:
fast += 1
arr[fast], arr[i] = arr[i], arr[fast]
slow += 1
arr[slow], arr[fast] = arr[fast], arr[slow]
return slow, fast + 1
三数取中(median-of-three)作为基准,改善分区的平衡性。
时间复杂度:降低了O(N * lgN)
的常数因子。
end-1
位置。end-1
。class Solution:
def MySort(self, arr):
"""迭代版"""
stack = [(0, len(arr) - 1)]
while stack:
start, end = stack.pop()
mid = self.partition(arr, start, end)
if start < mid - 1:
stack.append((start, mid - 1))
if mid + 1 < end:
stack.append((mid + 1, end))
return arr
def partition(self, arr, start, end):
"""分排函数【快慢指针交换版 + 三数取中为基准】"""
pviot_idx = self.swap_pivot(arr, start, end)
pivot = arr[pviot_idx]
slow = start - 1
for fast in range(start, pviot_idx):
if arr[fast] <= pivot:
arr[fast], arr[slow + 1] = arr[slow + 1], arr[fast]
slow += 1
arr[pviot_idx], arr[slow + 1] = arr[slow + 1], arr[pviot_idx]
return slow + 1
def swap_pivot(self, arr, start, end):
"""把首、中、尾三个元素排好序,再把中元素交换到 end-1 位置"""
mid = (start + end) // 2
if arr[start] > arr[mid]:
arr[start], arr[mid] = arr[mid], arr[start]
if arr[start] > arr[end]:
arr[start], arr[end] = arr[end], arr[start]
if arr[end] < arr[mid]:
arr[end], arr[mid] = arr[mid], arr[end]
arr[end-1], arr[mid] = arr[mid], arr[end-1]
return end-1
只需把分排函数的比较符号换成 >=
class Solution:
def MySort(self, arr):
stack = [(0, len(arr) - 1)]
while stack:
start, end = stack.pop()
mid = self.partition(arr, start, end)
if start < mid - 1:
stack.append((start, mid - 1))
if mid + 1 < end:
stack.append((mid + 1, end))
return arr
def partition(self, arr, start, end):
"""降序版"""
pivot = arr[end]
slow = start - 1
for fast in range(start, end):
if arr[fast] >= pivot:
arr[fast], arr[slow + 1] = arr[slow + 1], arr[fast]
slow += 1
arr[end], arr[slow + 1] = arr[slow + 1], arr[end]
return slow + 1
CPython 标准库函数 list.sort()
在 <= 2.2.3
的版本中,底层算法使用的是 SampleSort。而 >= 2.3.0
的版本中,底层算法使用的是 Timsort。这两种算法都是融合了多种算法的混合型算法。
SampleSort 算法可以视为一个“ N / lgN
取中”版本的快速排序。还可以将其视为一种桶排序,把数组元素分入 2^k
个桶,而桶边界(PP)是动态选择的。
时间复杂度:降低了O(N * lgN)
的常数因子。
2^k - 1
个元素作为预选枢轴(PP:preselected pivots),其中 k
等于最接近 N / lgN
的 2 幂次方,再减去 1 。把 PP 都交换到数组的最前面,并排好序。要实现的效果就是使用这个 2^k - 1
个 PP 把数组划分成 2^k
个区间。使得每个 PP 左边区间的元素都 < 该PP,右边的元素都 >= 该 PP。class SamplesortStackNode:
def __init__(self, stack_size):
self.lo = [0] * stack_size
self.hi = [0] * stack_size
self.extra = [0] * stack_size
class SampleSort:
MINSIZE = 100
MINPARTITIONSIZE = 40
MAXMERGE = 15
STACKSIZE = 60
CUTOFFBASE = 4
cutoff = [
43,
106,
250,
576,
1298,
2885,
6339,
13805,
29843,
64116,
137030,
291554,
617916,
1305130,
2748295,
5771662,
12091672,
25276798,
52734615,
109820537,
228324027,
473977813,
982548444,
2034159050
]
def binary_sort(self, arr, lo, hi, start):
assert lo <= start <= hi
if lo == start:
start += 1
while start < hi:
l = lo
r = start
pivot = arr[r]
while True:
p = l + ((r - l) >> 1)
if pivot < arr[p]:
r = p
else:
l = p + 1
if l >= r:
break
for p in range(start, l, -1):
arr[p] = arr[p-1]
arr[l] = pivot
start += 1
return arr
def sample_sort_slice(self, arr, lo, hi):
stack = SamplesortStackNode(self.STACKSIZE)
assert lo <= hi
n = hi - lo
"""
* ----------------------------------------------------------
* 特殊案例:已经有序、全部重复、在有序数组末尾追加几个随机元素。
* --------------------------------------------------------*
"""
if n < 2:
return arr
assert lo < hi
r = lo+1
for r in range(lo+1, hi):
if arr[r] < arr[r-1]:
break
if hi - r <= self.MAXMERGE or n < self.MINSIZE:
return self.binary_sort(arr, lo, hi, start=r)
assert lo < hi
for r in range(lo+1, hi):
if arr[r-1] < arr[r]:
break
if hi - r <= self.MAXMERGE:
originalr = r
l = lo
while True:
r -= 1
arr[l], arr[r] = arr[r], arr[l]
l += 1
if l >= r:
break
return self.binary_sort(arr, lo, hi, start=originalr)
"""
* ----------------------------------------------------------
* 普通案例: 没有明显模式的大型数组。
* --------------------------------------------------------
"""
extra = 0
for extra in range(0, len(self.cutoff)):
if n < self.cutoff[extra]:
break
assert self.MINSIZE >= 2 ** (self.CUTOFFBASE-1) - 1
# 1 << k == 2 ** k
extra = (1 << (extra - 1 + self.CUTOFFBASE)) - 1
assert (extra > 0) and (n >= extra)
seed = n // extra
for i in range(0, extra):
seed = seed * 69069 + 7
j = int(i + seed % (n - i))
arr[lo+i], arr[lo+j] = arr[lo+j], arr[lo+i] # tmp = lo[i]; lo[i] = lo[j]; lo[j] = tmp;
self.sample_sort_slice(arr, lo, lo + extra)
top = 0
lo += extra
extraOnRight = 0
"""
/* ----------------------------------------------------------
* 对 [lo, hi) 进行分区操作,重复此步骤,直到没有可处理的区间
* --------------------------------------------------------*/
"""
while True:
assert lo <= hi # so n >= 0
n = hi - lo
if n < self.MINPARTITIONSIZE or extra == 0:
if n >= self.MINSIZE:
assert extra == 0
self.sample_sort_slice(arr, lo, hi)
else:
if extraOnRight and extra:
k = extra
while True:
arr[lo], arr[hi] = arr[hi], arr[lo]
lo += 1
hi += 1
k -= 1
if k <= 0:
break
self.binary_sort(arr, lo-extra, hi, start=lo)
top -= 1
if top < 0:
break
lo = stack.lo[top]
hi = stack.hi[top]
extra = stack.extra[top]
extraOnRight = 0
if extra < 0:
extraOnRight = 1
extra = -extra
continue
extra >>= 1
k = extra
if extraOnRight:
while True:
arr[lo], arr[hi] = arr[hi], arr[lo]
lo += 1
hi += 1
k -= 1
if k < 0:
break
else:
while k > 0:
lo -= 1
hi -= 1
arr[lo], arr[hi] = arr[hi], arr[lo]
k -= 1
lo -= 1
pivot = arr[lo]
l = lo + 1
r = hi - 1
assert lo < l < r < hi
while True:
while True:
if arr[l] < pivot:
l += 1
else:
break
if l >= r:
break
while l < r:
rval = arr[r]
r -= 1
if rval < pivot:
arr[r+1] = arr[l]
arr[l] = rval
l += 1
break
if l >= r:
break
assert lo < r <= l < hi
assert r == l or r+1 == l
if l == r:
if arr[r] < pivot:
l += 1
else:
r -= 1
assert lo <= r and r+1 == l and l <= hi
assert r == lo or arr[r] < pivot
assert arr[lo] is pivot
assert l == hi or arr[l] >= pivot
arr[lo] = arr[r]
arr[r] = pivot
while l < hi:
if pivot < arr[l]:
break
else:
l += 1
assert lo <= r < l <= hi
assert top < self.STACKSIZE
if r - lo <= hi - l:
stack.lo[top] = l
stack.hi[top] = hi
stack.extra[top] = -extra
hi = r
extraOnRight = 0
else:
stack.lo[top] = lo
stack.hi[top] = r
stack.extra[top] = extra
lo = l
extraOnRight = 1
top += 1
return arr
class Solution:
def MySort(self, arr):
return SampleSort().sample_sort_slice(arr, 0, len(arr))
平均时间复杂度:O(N^2)
最差时间复杂度:O(N^2)
最优时间复杂度:O(N^2)
原址性:原址,只有交换元素时需要在原数组之外临时存储一个元素。
稳定性:不稳定,旧的最小值和新的最小值交换时,可能改变等值元素的相对顺序。
class Solution:
def MySort(self, arr):
n = len(arr)
for i in range(n):
min_index = i
for j in range(i+1, n):
if arr[j] < arr[min_index]:
min_index = j
if min_index != i:
arr[min_index], arr[i] = arr[i], arr[min_index]
return arr
class Solution:
def MySort(self, arr):
n = len(arr)
for i in range(n-1, 0, -1):
max_idx = i
for j in range(0, i):
if arr[j] > arr[max_idx]:
max_idx = j
if max_idx != i:
arr[max_idx], arr[i] = arr[i], arr[max_idx]
return arr
只需把比较运算符号换成 >,变量名称 min_index 改不改都没有影响。
class Solution:
def MySort(self, arr):
"""降序版"""
n = len(arr)
for i in range(n):
min_index = i
for j in range(i+1, n):
if arr[j] > arr[min_index]:
min_index = j
if min_index != i:
arr[min_index], arr[i] = arr[i], arr[min_index]
return arr
优点:使用完全二叉树保存上一轮的选择过程,下轮选择时利用这些信息减少比较次数,从而降低了时间复杂度,由O(N^2)
降到O(N * lgN)
。
缺点: 辅助空间使用多,与无穷大的比较多余。(堆排序弥补了这些缺点)
Tournament Tree:叶子节点是参赛者,非叶子节点是赢家。
2 ^ k
,k 是满足2 ^ k >= n
的最小值)、总节点数量(叶子节点数量 * 2 - 1
)。注意树列表下标从1开始,忽略下标0,否则后续计算父节点下标会略麻烦。class Solution:
def MySort(self, arr):
"""选最小值,树下标从1开始,稳定排序"""
n = len(arr)
max_float = float("inf")
base_size = 1
while base_size < n:
base_size *= 2
tree_size = base_size * 2 - 1
tree = [0] * (tree_size + 1)
# 填充叶子节点
i = -1
while i >= -n:
tree[i] = arr[i]
i -= 1
while i >= -base_size:
tree[i] = max_float
i -= 1
# 构建树
i = tree_size
while i > 0:
if tree[i] < tree[i-1]:
tree[i // 2] = tree[i]
else:
tree[i // 2] = tree[i-1]
i -= 2
# 排序:拿走最小值并调整树,奇数下标是右节点,当左右节点相等时,总是选择左节点,从而保证稳定排序
sorted_idx = 0
while sorted_idx < n:
min_val = tree[1]
arr[sorted_idx] = min_val
sorted_idx += 1
min_idx = tree_size
while tree[min_idx] != min_val:
min_idx -= 1
tree[min_idx] = max_float
while min_idx > 1:
if min_idx % 2 == 0:
if tree[min_idx] <= tree[min_idx+1]:
tree[min_idx // 2] = tree[min_idx]
else:
tree[min_idx // 2] = tree[min_idx+1]
else:
if tree[min_idx] < tree[min_idx-1]:
tree[min_idx // 2] = tree[min_idx]
else:
tree[min_idx // 2] = tree[min_idx-1]
min_idx //= 2
return arr
平均时间复杂度:O(N * lgN)
最差时间复杂度:O(N * lgN)
详见《算法导论》6.4 章节
最优时间复杂度:O(N * lgN)
原址性:原址,只有交换元素时需要在原数组之外临时存储一个元素。
稳定性:不稳定,排序时需要把 0 节点元素与end 节点元素交换,可能破环end 节点与其等值元素的相对次序。
构建最大堆
堆排序
调整堆函数有三种写法:递归版、迭代版(栈)、迭代版(变量更新)。
调整堆函数
class Solution:
def MySort(self, arr):
n = len(arr)
# 【一】构建最大堆
for root in range(n-1, -1, -1):
self.adjust_heap(arr, root, n-1)
# 【二】堆排序
for end in range(n-1, 0, -1):
arr[0], arr[end] = arr[end], arr[0]
self.adjust_heap(arr, 0, end-1)
return arr
def adjust_heap(self, arr, root, end):
"""递归版"""
child = root * 2 + 1
if child <= end:
if child+1 <= end and arr[child+1] > arr[child]:
child += 1
if arr[child] > arr[root]:
arr[child], arr[root] = arr[root], arr[child]
self.adjust_heap(arr, child, end)
调整堆函数可以用栈来实现,即把递归函数的变参 root 放到栈里,while循环的下轮再弹出,直到栈空结束。
class Solution:
def MySort(self, arr):
n = len(arr)
# 【一】构建最大堆
for root in range(n-1, -1, -1):
self.adjust_heap(arr, root, n-1)
# 【二】堆排序
for end in range(n-1, 0, -1):
arr[0], arr[end] = arr[end], arr[0]
self.adjust_heap(arr, 0, end-1)
return arr
def adjust_heap(self, arr, root, end):
"""迭代版:栈"""
stack = [root]
while stack:
root = stack.pop()
child = root * 2 + 1
if child <= end:
if child+1 <= end and arr[child+1] > arr[child]:
child += 1
if arr[child] > arr[root]:
arr[child], arr[root] = arr[root], arr[child]
stack.append(child)
把递归调用自己时的变参 root 更新一下,while 循环的下轮继续使用,需要显示声明循环结束的2种情形:根结点比两个一层子节点都大;根结点没有子节点了。
class Solution:
def MySort(self, arr):
n = len(arr)
# 【一】构建最大堆
for root in range(n-1, -1, -1):
self.adjust_heap(arr, root, n-1)
# 【二】堆排序
for end in range(n-1, 0, -1):
arr[0], arr[end] = arr[end], arr[0]
self.adjust_heap(arr, 0, end-1)
return arr
def adjust_heap(self, arr, root, end):
"""迭代版:变量更新"""
while True:
child = root * 2 + 1
if child > end:
break
if child + 1 <= end and arr[child + 1] > arr[child]:
child += 1
if arr[child] > arr[root]:
arr[child], arr[root] = arr[root], arr[child]
root = child
else:
break
注意
构建最大堆时:可以先计算出最后一个非叶子节点下标,从最后一个非叶子节点到 0 节点,依次作为根结点,调整成合法堆。
不计算也没什么影响,因为调整堆函数中会先判断是否有子节点,没有就直接返回了。
class Solution:
def MySort(self, arr):
n = len(arr)
# 【一】构建最大堆
last_no_leaf = int(n/2 - 1)
for root in range(last_no_leaf, -1, -1):
self.adjust_heap(arr, root, n-1)
降序需要构建最小堆,其实只需把调整堆函数中的两个比较符号换成 <
class Solution:
def MySort(self, arr):
n = len(arr)
# 【一】构建最小堆
for root in range(n-1, -1, -1):
self.adjust_heap(arr, root, n-1)
# 【二】堆排序
for end in range(n-1, 0, -1):
arr[0], arr[end] = arr[end], arr[0]
self.adjust_heap(arr, 0, end-1)
return arr
def adjust_heap(self, arr, root, end):
"""降序版"""
child = root * 2 + 1
if child <= end:
if child+1 <= end and arr[child+1] < arr[child]:
child += 1
if arr[child] < arr[root]:
arr[child], arr[root] = arr[root], arr[child]
self.adjust_heap(arr, child, end)
平均时间复杂度:O(N^2)
详见《算法导论》2.2 章节
最差时间复杂度:O(N^2)
最优时间复杂度:O(N)
原址性:原址,只有交换元素时需要在原数组之外临时存储一个元素。
稳定性:稳定,新元素只会插入到其等值元素的后面。
class Solution:
def MySort(self , arr ):
"""排升序"""
for new in range(1, len(arr)):
while new-1 >= 0 and arr[new] < arr[new-1]:
arr[new], arr[new-1] = arr[new-1], arr[new]
new -= 1
return arr
class Solution:
def MySort(self, arr):
"""挖坑版"""
for i in range(1, len(arr)):
new_val = arr[i]
while i - 1 >= 0 and new_val < arr[i - 1]:
arr[i] = arr[i - 1]
i -= 1
arr[i] = new_val
return arr
只需把新旧牌的比较符号换成 >
class Solution:
def MySort(self, arr):
"""降序版"""
for new in range(1, len(arr)):
while new - 1 >= 0 and arr[new] > arr[new-1]:
arr[new], arr[new-1] = arr[new-1], arr[new]
new -= 1
return arr
二分查找能减少比较次数,不影响移动次数,虽然能够提高查找效率,但时间复杂度不变。 详解
class Solution:
def MySort(self, arr):
"""折半插入排序"""
for i in range(1, len(arr)):
new_val = arr[i]
left = 0
right = i - 1
while left <= right:
mid = (left + right) // 2
if new_val < arr[mid]:
right = mid - 1
else:
left = mid + 1
for j in range(i, left, -1):
arr[j] = arr[j-1]
arr[left] = new_val
return arr
2-路插入排序在二分插入排序的基础上减少移动次数,虽然能够同时提高查找和移动效率,但时间复杂度不变。
class Solution:
def MySort(self, arr):
first = 0
last = 0
tmp_list = [0] * len(arr)
tmp_list[0] = arr[0]
for i in range(1, len(arr)):
new_val = arr[i]
if new_val < tmp_list[first]:
first -= 1
tmp_list[first] = new_val
elif new_val >= tmp_list[last]:
last += 1
tmp_list[last] = new_val
else:
left = first
right = last - 1
while left <= right:
mid = (left + right) // 2
if new_val < tmp_list[mid]:
right = mid - 1
else:
left = mid + 1
last += 1
for j in range(last, left, -1):
tmp_list[j] = tmp_list[j - 1]
tmp_list[left] = new_val
return [tmp_list[idx] for idx in range(first, last+1)]
希尔排序是多轮步长递减的插入排序。
由于:
O(N)
复杂度,效率极高。希尔排序一开始的数据移动幅度很大,逐渐降到1,提高了【场景2】下的效率,又利用了【场景1】中的优势。
平均时间复杂度:根据增量序列的不同而不同,大体范围O(N^M),1 < M < 2
,详见 排序算法之希尔排序及其增量序列
最差时间复杂度:根据增量序列的不同而不同,大体范围O(N^M),1 < M < 2
最优时间复杂度:O(N)
原址性:原址,只有交换元素时需要在原数组之外临时存储一个元素。
稳定性:不稳定,虽然一次插入排序是稳定的,但等值元素会在不同增量的插入排序中移动,其相对次序可能被打乱。
希尔排序提出时,其原始增量序列为:{1,2,4,8,...}
,最差时间复杂度为O(N^2)
。
gap = len(arr) // 2
。gap //= 2
。class Solution:
def MySort(self, arr):
gap = len(arr) // 2
while gap >= 1:
for new in range(gap, len(arr)):
while new-gap >= 0 and arr[new] < arr[new-gap]:
arr[new], arr[new-gap] = arr[new-gap], arr[new]
new -= gap
gap //= 2
return arr
Hibbard提出了另一个增量序列{1,3,7,...,2^k-1}
,最差时间复杂度为O(N^1.5)
,平均时间复杂度约为O(N^1.25)
。
class Solution:
def MySort(self, arr):
gap_list = self.get_gap_list(len(arr))
for gap in reversed(gap_list):
for new in range(gap, len(arr)):
while new - gap >= 0 and arr[new] < arr[new - gap]:
arr[new], arr[new - gap] = arr[new - gap], arr[new]
new -= gap
return arr
def get_gap_list(self, n):
k = 0
gap_list = list()
while True:
gap = 2 ** k - 1
if gap <= n:
gap_list.append(gap)
else:
break
k += 1
return gap_list
Sedgewick提出了几种增量序列,最差时间复杂度为O(N^1.33)
,平均时间复杂度约为O(N^1.17)
,其中最好的一个序列是{1,5,19,41,109,...}
,其生成序列是D = 9 * 4^i - 9 * 2^i + 1
或 4^(i+2) - 3 * 2^(i+2) + 1
,其中 i>=0
。
class Solution:
def MySort(self, arr):
gap_list = self.get_gap_list(len(arr))
for gap in reversed(gap_list):
for new in range(gap, len(arr)):
while new - gap >= 0 and arr[new] < arr[new - gap]:
arr[new], arr[new - gap] = arr[new - gap], arr[new]
new -= gap
return arr
def get_gap_list(self, n):
i = 0
gap_list = list()
while True:
gap = 9 * (4**i) - 9 * (2**i) + 1
if gap <= n:
gap_list.append(gap)
gap = 4 ** (i+2) - 3 * 2 ** (i+2) + 1
if gap <= n:
gap_list.append(gap)
else:
break
i += 1
return gap_list
跟插入排序一样,只需把新旧牌的比较符号换成 >
class Solution:
def MySort(self, arr):
gap = len(arr) // 2
while gap >= 1:
for new in range(gap, len(arr)):
while new-gap >= 0 and arr[new] > arr[new-gap]:
arr[new], arr[new-gap] = arr[new-gap], arr[new]
new -= gap
gap //= 2
return arr
平均时间复杂度:O(N * lgN)
最差时间复杂度:O(N * lgN)
详见《算法导论》2.3.2 章节
最优时间复杂度:O(N * lgN)
原址性:非原址,合并时需要在原数组之外临时存储输出数组。
稳定性:稳定,合并时遇到等值元素总是选择左子区间的。
class Solution:
def MySort(self, arr):
"""自上而下的递归"""
if len(arr) == 1:
return arr
middle_idx = len(arr) // 2
return self.merge(self.MySort(arr[:middle_idx]), self.MySort(arr[middle_idx:]))
def merge(self, left, right):
"""合并函数:pop(0)写法"""
ret = list()
while left and right:
if left[0] <= right[0]:
ret.append(left.pop(0))
else:
ret.append(right.pop(0))
return ret + left + right
class Solution:
def MySort(self, arr):
"""自下而上的迭代"""
step = 1
while step < len(arr):
ls = 0
while ls + step < len(arr):
le = rs = ls + step
re = rs + step
if re > len(arr):
re = len(arr)
arr[ls:re] = self.merge(arr[ls:le], arr[rs:re])
ls += step * 2
step *= 2
return arr
def merge(self, left, right):
"""合并函数:指针写法"""
ret = list()
i = j = 0
while i < len(left) and j < len(right):
if left[i] <= right[j]:
ret.append(left[i])
i += 1
else:
ret.append(right[j])
j += 1
return ret + left[i:] + right[j:]
只需把 merge 函数的比较符号换成 >=
class Solution:
def MySort(self, arr):
"""降序版"""
if len(arr) == 1:
return arr
middle_idx = len(arr) // 2
return self.merge(self.MySort(arr[:middle_idx]), self.MySort(arr[middle_idx:]))
def merge(self, left, right):
"""合并函数:pop(0)写法"""
ret = list()
while left and right:
if left[0] >= right[0]:
ret.append(left.pop(0))
else:
ret.append(right.pop(0))
return ret + left + right
TimSort 是一种自适应的、稳定的、自然的 MergeSort。
时间复杂度:降低了O(N * lgN)
的常数因子。
CPython 标准库函数 list.sort()
在 <= 2.2.3
的版本中,底层算法使用的是 SampleSort。而 >= 2.3.0
的版本中,底层算法使用的是 Timsort。
run :数组中本身存在的有序区间,如果是降序则会就地转成升序。
gallop 查找法:先通过指数查找法(1、3、7、15、31、63 … 2^n-1),快速找到正确位置的大致范围,然后在这个这个范围中用二分查找法锁定最终的正确位置。
Python 版 TimSort 算法
Timsort 介绍(listsort.txt 翻译)
class Solution:
def __init__(self):
"""初始化阈值、堆栈"""
self.MIN_MERGE = 32
self.MIN_GALLOP = 7
self.min_gallop = self.MIN_GALLOP
self.stack_size = 0
self.run_base = []
self.run_len = []
def MySort(self, arr):
"""排序方法"""
lo = 0
hi = len(arr)
assert (arr is not None) and (lo >= 0) and (lo <= hi) and (hi <= len(arr))
n_remaining = hi - lo
if n_remaining < 2:
return arr
if n_remaining < self.MIN_MERGE:
init_run_len = self.countRunAndMakeAscending(arr, lo, hi)
self.binarySort(arr, lo, hi, lo + init_run_len)
return arr
min_run = self.minRunLength(n_remaining)
while n_remaining:
run_len = self.countRunAndMakeAscending(arr, lo, lo + n_remaining)
if run_len < min_run:
force = n_remaining if n_remaining <= min_run else min_run
self.binarySort(arr, lo=lo, hi=lo+force, start=lo+run_len)
run_len = force
self.pushRun(lo, run_len)
self.mergeCollapse(arr)
lo += run_len
n_remaining -= run_len
assert lo == hi
self.mergeForceCollapse(arr)
assert self.stack_size == 1
return arr
def binarySort(self, arr, lo, hi, start):
"""对数组的指定区间进行二分插入排序,通过 start 参数能跳过已知有序区,直接从无序区开始排序"""
assert (lo <= start) and (start <= hi)
if lo == start:
start += 1
while start < hi:
pivot = arr[start]
left = lo
right = start
assert left <= right
while left < right:
mid = (left + right) >> 1
if pivot < arr[mid]:
right = mid
else:
left = mid + 1
assert left == right
n = start - left
arr[left + 1:left + 1 + n] = arr[left:left + n]
arr[left] = pivot
start += 1
def countRunAndMakeAscending(self, arr, lo, hi):
"""从数组的指定位置开始寻找有序部分作为初始 run,如果是降序则就地转成升序,返回有序部分的长度"""
assert lo < hi
run_hi = lo + 1
if run_hi == hi:
return 1
if arr[run_hi] < arr[lo]:
run_hi += 1
while run_hi < hi and arr[run_hi] < arr[run_hi - 1]:
run_hi += 1
arr[lo:run_hi] = reversed(arr[lo:run_hi])
else:
while run_hi < hi and arr[run_hi] >= arr[run_hi - 1]:
run_hi += 1
return run_hi - lo
def minRunLength(self, n):
"""根据待排序元素个数计算 min_run(最小 run 长度)"""
assert n >= 0
r = 0
while n >= self.MIN_MERGE:
r |= n & 1
n >>= 1
return n + r
def pushRun(self, run_base, run_len):
"""把 run 的起始位置和长度这两个变量放入堆栈"""
if len(self.run_base) <= self.stack_size:
self.run_base.append(run_base)
self.run_len.append(run_len)
else:
self.run_base[self.stack_size] = run_base
self.run_len[self.stack_size] = run_len
self.stack_size += 1
def mergeCollapse(self, arr):
"""检查堆栈中的 run,如果符合以下条件则进行归并:
1)如果堆栈中有三个及以上的 run,则归并条件为:倒三长度 <= 倒二长度 + 倒一长度
2)如果堆栈中只有两个 run,则归并条件为:倒二长度 <= 倒一长度
"""
while self.stack_size > 1:
n = self.stack_size - 2
if n > 0 and self.run_len[n-1] <= self.run_len[n] + self.run_len[n+1]:
if self.run_len[n-1] < self.run_len[n+1]:
n -= 1
self.mergeAt(arr, n)
elif self.run_len[n] <= self.run_len[n+1]:
self.mergeAt(arr, n)
else:
break
def mergeForceCollapse(self, arr):
"""归并堆栈中的所有 run,直到只剩一个,这个方法只会调用一次来完成排序"""
while self.stack_size > 1:
n = self.stack_size - 2
if n > 0 and self.run_len[n - 1] < self.run_len[n + 1]:
n -= 1
self.mergeAt(arr, n)
def mergeAt(self, arr, i):
"""归并堆栈中的第 i 个和第 i+1 个 run。i 必须是 run 堆栈中的倒数第二或倒数第三个下标"""
assert self.stack_size >= 2
assert i >= 0
assert i == self.stack_size - 2 or i == self.stack_size - 3
base1 = self.run_base[i]
len1 = self.run_len[i]
base2 = self.run_base[i + 1]
len2 = self.run_len[i + 1]
assert len1 > 0 and len2 > 0
assert base1 + len1 == base2
self.run_len[i] = len1 + len2
if i == self.stack_size - 3:
self.run_base[i + 1] = self.run_base[i + 2]
self.run_len[i + 1] = self.run_len[i + 2]
self.stack_size -= 1
"""分别缩小两个 run 的待归并区间"""
k = self.gallopRight(key=arr[base2], arr=arr, base=base1, len_=len1, hint=0)
assert k >= 0
base1 += k
len1 -= k
if len1 == 0:
return
len2 = self.gallopLeft(key=arr[base1+len1-1], arr=arr, base=base2, len_=len2, hint=len2-1)
assert len2 >= 0
if len2 == 0:
return
"""归并这两个已缩小的 run"""
if len1 <= len2:
self.mergeLo(arr, base1, len1, base2, len2)
else:
self.mergeHi(arr, base1, len1, base2, len2)
def gallopLeft(self, key, arr, base, len_, hint):
"""在数组的待搜索区间中,查找指定值(key)应当插入的正确位置;如果存在与 key 等值的元素,则返回是这些相等值中最左边的位置"""
assert (len_ > 0) and (hint >= 0) and (hint < len_)
last_ofs = 0
ofs = 1
if key > arr[base + hint]:
max_ofs = len_ - hint
while ofs < max_ofs and key > arr[base + hint + ofs]:
last_ofs = ofs
ofs = (ofs << 1) + 1
if ofs <= 0:
ofs = max_ofs
if ofs > max_ofs:
ofs = max_ofs
last_ofs += hint
ofs += hint
else:
max_ofs = hint + 1
while ofs < max_ofs and key <= arr[base + hint - ofs]:
last_ofs = ofs
ofs = (ofs << 1) + 1
if ofs <= 0:
ofs = max_ofs
if ofs > max_ofs:
ofs = max_ofs
tmp = last_ofs
last_ofs = hint - ofs
ofs = hint - tmp
assert (-1 <= last_ofs) and (last_ofs < ofs) and (ofs <= len_)
last_ofs += 1
while last_ofs < ofs:
m = last_ofs + ((ofs - last_ofs) >> 1)
if key > arr[base + m]:
last_ofs = m + 1
else:
ofs = m
assert last_ofs == ofs
return ofs
def gallopRight(self, key, arr, base, len_, hint):
"""与 gallopRight 相似,区别在于:如果存在与 key 等值的元素,则返回是这些相等值中最右边的位置"""
assert (len_ > 0) and (hint >= 0) and (hint < len_)
last_ofs = 0
ofs = 1
if key < arr[base + hint]:
max_ofs = hint + 1
while ofs < max_ofs and key < arr[base + hint - ofs]:
last_ofs = ofs
ofs = (ofs << 1) + 1
if ofs <= 0:
ofs = max_ofs
if ofs > max_ofs:
ofs = max_ofs
tmp = last_ofs
last_ofs = hint - ofs
ofs = hint - tmp
else:
max_ofs = len_ - hint
while ofs < max_ofs and key >= arr[base + hint + ofs]:
last_ofs = ofs
ofs = (ofs << 1) + 1
if ofs <= 0:
ofs = max_ofs
if ofs > max_ofs:
ofs = max_ofs
last_ofs += hint
ofs += hint
assert (-1 <= last_ofs) and (last_ofs < ofs) and (ofs <= len_)
last_ofs += 1
while last_ofs < ofs:
m = last_ofs + ((ofs - last_ofs) >> 1)
if key > arr[base + m]:
last_ofs = m + 1
else:
ofs = m
assert last_ofs == ofs
return ofs
def mergeLo(self, arr, base1, len1, base2, len2):
"""从左端的最小值开始检查,按小值合并两个相邻的 run,保持排序的稳定性"""
assert len1 > 0 and len2 > 0 and base1 + len1 == base2
tmp = []
cursor1 = 0
cursor2 = base2
dest = base1
tmp[cursor1:cursor1+len1] = arr[base1:base1+len1]
arr[dest] = arr[cursor2]
dest += 1
cursor2 += 1
len2 -= 1
if len2 == 0:
arr[dest:dest+len1] = tmp[cursor1:cursor1+len1]
return
if len1 == 1:
arr[dest:dest+len2] = arr[cursor2:cursor2+len2]
arr[dest+len2] = tmp[cursor1]
return
min_gallop = self.min_gallop
break_outer = False
while True:
count1 = 0
count2 = 0
"""【普通归并模式】"""
while True:
assert len1 > 1 and len2 > 0
if arr[cursor2] < tmp[cursor1]:
arr[dest] = arr[cursor2]
dest += 1
cursor2 += 1
count2 += 1
count1 = 0
len2 -= 1
if len2 == 0:
break_outer = True
break
else:
arr[dest] = tmp[cursor1]
dest += 1
cursor1 += 1
count1 += 1
count2 = 0
len1 -= 1
if len1 == 1:
break_outer = True
break
if (count1 | count2) >= min_gallop:
break
if break_outer:
break
"""【GALLOP 模式】"""
while True:
assert len1 > 1 and len2 > 0
count1 = self.gallopRight(key=arr[cursor2], arr=tmp, base=cursor1, len_=len1, hint=0)
if count1 != 0:
arr[dest:dest+count1] = tmp[cursor1:cursor1+count1]
dest += count1
cursor1 += count1
len1 -= count1
if len1 <= 1:
break_outer = True
break
arr[dest] = arr[cursor2]
dest += 1
cursor2 += 1
len2 -= 1
if len2 == 0:
break_outer = True
break
count2 = self.gallopLeft(key=tmp[cursor1], arr=arr, base=cursor2, len_=len2, hint=0)
if count2 != 0:
arr[dest:dest+count2] = arr[cursor2:cursor2+count2]
dest += count2
cursor2 += count2
len2 -= count2
if len2 == 0:
break_outer = True
break
arr[dest] = tmp[cursor1]
dest += 1
cursor1 += 1
len1 -= 1
if len1 == 1:
break_outer = True
break
min_gallop -= 1
if not (count1 >= self.MIN_GALLOP | count2 >= self.MIN_GALLOP):
break
if break_outer:
break
if min_gallop < 0:
min_gallop = 0
min_gallop += 2
self.min_gallop = 1 if min_gallop < 1 else min_gallop
if len1 == 1:
assert len2 > 0
arr[dest:dest+len2] = arr[cursor2:cursor2+len2]
arr[dest + len2] = tmp[cursor1]
elif len1 == 0:
raise Exception("IllegalArgument")
else:
assert len2 == 0
assert len1 > 1
arr[dest:dest+len1] = tmp[cursor1:cursor1+len1]
def mergeHi(self, arr, base1, len1, base2, len2):
"""与 mergeLo 相似,区别在于从右端的最大值开始检查,按大值合并"""
assert len1 > 0 and len2 > 0 and base1 + len1 == base2
tmp = []
tmp_base = 0
tmp[tmp_base:tmp_base+len2] = arr[base2:base2+len2]
cursor1 = base1 + len1 - 1
cursor2 = tmp_base + len2 - 1
dest = base2 + len2 - 1
arr[dest] = arr[cursor1]
dest -= 1
cursor1 -= 1
len1 -= 1
if len1 == 0:
arr[dest+1-len2:dest+1] = tmp[tmp_base:tmp_base+len2]
return
if len2 == 1:
dest -= 1
cursor1 -= len1
arr[dest+1:dest+1+len1] = arr[cursor1+1:cursor1+1+len1]
arr[dest] = tmp[cursor2]
return
min_gallop = self.min_gallop
break_outer = False
while True:
count1 = 0
count2 = 0
"""【普通归并模式】"""
while True:
assert len1 > 0 and len2 > 1
if tmp[cursor2] < arr[cursor1]:
arr[dest] = arr[cursor1]
dest -= 1
cursor1 -= 1
count1 += 1
count2 = 0
len1 -= 1
if len1 == 0:
break_outer = True
break
else:
arr[dest] = tmp[cursor2]
dest -= 1
cursor2 -= 1
count2 += 1
count1 = 0
len2 -= 1
if len2 == 1:
break_outer = True
break
if (count1 | count2) >= min_gallop:
break
if break_outer:
break
""""【GALLOP 模式】"""
while True:
assert len1 > 0 and len2 > 1
count1 = len1 - self.gallopRight(key=tmp[cursor2], arr=arr, base=base1, len_=len1, hint=len1-1)
if count1 != 0:
dest -= count1
cursor1 -= count1
len1 -= count1
arr[dest+1:dest+1+count1] = arr[cursor1+1:cursor1+1+count1]
if len1 == 0:
break_outer = True
break
arr[dest] = tmp[cursor2]
dest -= 1
cursor2 -= 1
len2 -= 1
if len2 == 1:
break_outer = True
break
count2 = len2 - self.gallopLeft(key=arr[cursor1], arr=tmp, base=tmp_base, len_=len2, hint=len2-1)
if count2 != 0:
dest -= count2
cursor2 -= count2
len2 -= count2
arr[dest+1:dest+1+count2] = tmp[cursor2+1:cursor2+1+count2]
if len2 <= 1:
break_outer = True
break
arr[dest] = arr[cursor1]
dest -= 1
cursor1 -= 1
len1 -= 1
if len1 == 0:
break_outer = True
break
min_gallop -= 1
if not (count1 >= self.MIN_GALLOP | count2 >= self.MIN_GALLOP):
break
if break_outer:
break
if min_gallop < 0:
min_gallop = 0
min_gallop += 2
self.min_gallop = 1 if min_gallop < 1 else min_gallop
if len2 == 1:
assert len1 > 0
dest -= len1
cursor1 -= len1
arr[dest+1:dest+1+len1] = arr[cursor1+1:cursor1+1+len1]
arr[dest] = tmp[cursor2]
elif len2 == 0:
raise Exception("IllegalArgumentException")
else:
assert len1 == 0
assert len2 > 0
arr[dest+1-len2:dest+1] = tmp[tmp_base:tmp_base+len2]
以下四种非比较排序算法都基于【分桶+合并桶】的思想,也称基于分配的排序算法。《算法导论》在第 8 章(线性时间排序)中讨论了这些算法。
可以视为计数排序的简易版,是不稳定的排序算法,等值元素的相对次序无法区分。
class Solution:
def MySort(self, arr):
"""返回新数组"""
max_val = max(arr)
count_list = [0] * (max_val+1)
ret = list()
for val in arr:
count_list[val] += 1
for val in range(len(count_list)):
if count_list[val] != 0:
ret.extend([val] * count_list[val])
return ret
为了缩短 count_list,按最小值进行偏移:
class Solution:
def MySort(self, arr):
"""覆盖原数组 + 偏移最小值"""
max_val = max(arr)
min_val = min(arr)
count_list = [0] * (max_val - min_val + 1)
sorted_idx = 0
for val in arr:
count_list[val - min_val] += 1
for val in range(len(count_list)):
while count_list[val] > 0:
arr[sorted_idx] = val + min_val
sorted_idx += 1
count_list[val] -= 1
return arr
假设前提:输入元素都是 [0,K]
区间内的整数,K 为某个整数。
平均时间复杂度:O(N + K)
,当 K = O(N)
时,计数排序是线性运行时间O(N)
,详见《算法导论》8.2 章节。
最差时间复杂度:O(N + K)
最优时间复杂度:O(N + K)
所需辅助空间:O(N + K)
,需要长度为 N 的空间保存输出数组,长度为 K 的空间保存桶列表本身。
稳定性:稳定,使用累加和能唯一标识每个元素的最终排序位置,配合倒序遍历原数组,使得等值元素的相对次序不变。
class Solution:
def MySort(self, arr):
max_val = max(arr)
min_val = min(arr)
count_list = [0] * (max_val - min_val + 1)
ret = [0] * len(arr)
for val in arr:
count_list[val - min_val] += 1
for i in range(1, len(count_list)):
count_list[i] += count_list[i - 1]
for i in range(len(arr) - 1, -1, -1):
sorted_idx = count_list[arr[i] - min_val] - 1
ret[sorted_idx] = arr[i]
count_list[arr[i] - min_val] -= 1
return ret
由于序时 count_list 值代表的是最终排序位置,用 n - 1 减此升序位置就得到了降序位置。
class Solution:
def MySort(self, arr):
max_val = max(arr)
min_val = min(arr)
count_list = [0] * (max_val - min_val + 1)
ret = [0] * len(arr)
for val in arr:
count_list[val - min_val] += 1
for i in range(1, len(count_list)):
count_list[i] += count_list[i - 1]
for i in range(len(arr) - 1, -1, -1):
sorted_idx = (len(arr) - 1) - (count_list[arr[i] - min_val] - 1)
ret[sorted_idx] = arr[i]
count_list[arr[i] - min_val] -= 1
return ret
平均时间复杂度:O(N)
,即使输入数据不服从均匀分布,只要满足:所有桶的大小的平方和与总的元素数量呈线性关系,桶排序仍然可以在线性时间内完成,详见《算法导论》8.4 章节。
最差时间复杂度:O(N^2)
最优时间复杂度:O(N)
所需辅助空间:O(N+M)
,N 个输入元素分别要拷贝到各个桶内,总共 M 个桶。
稳定性:取决于桶内元素排序时使用的算法,使用插入排序和自递归时稳定,使用快速排序不稳定。
bucket_size = (max_val - min_val) / (bucket_num - 1)
),或者根据预定的桶大小确定桶数量(bucket_num = (max_val - min_val) // bucket_size + 1
),初始化 bucket_list。class Solution:
def MySort(self, arr):
max_val = max(arr)
min_val = min(arr)
bucket_num = len(arr)
bucket_size = (max_val - min_val) / (bucket_num - 1)
bucket_list = [[] for _ in range(bucket_num)]
for val in arr:
bucket_idx = int((val - min_val) / bucket_size)
bucket_list[bucket_idx].append(val)
# 返回的是新数组
ret = list()
for bucket in bucket_list:
bucket = self.insertSort(bucket)
ret.extend(bucket)
return ret
def insertSort(self, arr):
for new in range(1, len(arr)):
while new - 1 >= 0 and arr[new] < arr[new-1]:
arr[new], arr[new-1] = arr[new-1], arr[new]
new -= 1
return arr
class Solution:
def MySort(self, arr):
max_val = max(arr)
min_val = min(arr)
if max_val == min_val or len(arr) == 1:
return arr
bucket_num = len(arr)
bucket_size = (max_val - min_val) / (bucket_num - 1)
bucket_list = [[] for _ in range(bucket_num)]
for val in arr:
bucket_idx = int((val - min_val) / bucket_size)
bucket_list[bucket_idx].append(val)
# 覆盖原数组
sorted_start_idx, sorted_end_idx = 0, 0
for bucket in bucket_list:
if bucket:
sorted_end_idx += len(bucket)
arr[sorted_start_idx:sorted_end_idx] = self.MySort(bucket)
sorted_start_idx = sorted_end_idx
return arr
只需把合并桶的顺序改成倒着遍历
class Solution:
def MySort(self, arr):
"""降序版"""
max_val = max(arr)
min_val = min(arr)
if max_val == min_val or len(arr) == 1:
return arr
bucket_num = len(arr)
bucket_size = (max_val - min_val) / (bucket_num - 1)
bucket_list = [[] for _ in range(bucket_num)]
for val in arr:
bucket_idx = int((val - min_val) / bucket_size)
bucket_list[bucket_idx].append(val)
# 降序需要倒着合并桶
sorted_start_idx, sorted_end_idx = 0, 0
for i in range(len(bucket_list)-1, -1, -1):
if bucket_list[i]:
sorted_end_idx += len(bucket_list[i])
arr[sorted_start_idx:sorted_end_idx] = self.MySort(bucket_list[i])
sorted_start_idx = sorted_end_idx
return arr
平均时间复杂度:O(d * (N+K) )
,d
是位数,其中每一个位有 K
个可能的取值。N+K
表示按位排序时使用的是O(N+K)
的稳定排序算法(例如,计数排序)。当 d
是常数且 K = O(N)
时,基数排序是线性运行时间O(N)
,详见《算法导论》8.3 章节。
最差时间复杂度:O(d * (N+K) )
最优时间复杂度:O(d * (N+K) )
所需辅助空间:O(N + K)
,需要长度为 N 的空间在桶列表内保存每一轮的按位排序结果,长度为 K 的空间保存桶列表本身。
稳定性:如果按位排序时使用的是稳定排序算法(例如,计数排序),则基数排序稳定。
LSD:Least Significant Digit first
class Solution:
def MySort(self, arr):
"""最低位优先LSD"""
max_val = max(arr)
max_digit = 1
while max_val >= 10 ** max_digit:
max_digit += 1
for digit in range(max_digit):
bucket_list = [[] for _ in range(10)]
for val in arr:
bucket_idx = int(val / (10**digit) % 10)
bucket_list[bucket_idx].append(val)
sorted_idx = 0
for bucket in bucket_list:
for val in bucket:
arr[sorted_idx] = val
sorted_idx += 1
return arr
MSD:Most Significant Digit first
class Solution:
def MySort(self, arr):
"""最高位优先MSD"""
if len(arr) <= 1:
return arr
max_val = max(arr)
max_digit = 1
while max_val >= 10 ** max_digit:
max_digit += 1
return self.sort_digit_recur(arr, max_digit-1)
def sort_digit_recur(self, arr, digit):
if digit < 0:
return arr
bucket_list = [[] for _ in range(10)]
for val in arr:
bucket_idx = int(val / (10 ** digit) % 10)
bucket_list[bucket_idx].append(val)
ret = list()
for bucket in bucket_list:
if bucket:
sorted_bucket = self.sort_digit_recur(bucket, digit - 1)
ret.extend(sorted_bucket)
return ret
只需把每轮合并桶的顺序改成倒着遍历
class Solution:
def MySort(self, arr):
"""降序版"""
max_val = max(arr)
max_digit = 1
while max_val >= 10 ** max_digit:
max_digit += 1
for digit in range(max_digit):
bucket_list = [[] for _ in range(10)]
for val in arr:
bucket_idx = int(val / (10**digit) % 10)
bucket_list[bucket_idx].append(val)
sorted_idx = 0
for i in range(len(bucket_list)-1, -1, -1):
for val in bucket_list[i]:
arr[sorted_idx] = val
sorted_idx += 1
return arr