TimSort 中的核心过程

TimSort 是 Python 中 list.sort 的默认实现。Java 7 也将非原始类型列表的排序实现替换成了 TimSort。网上关于 TimSort 是什么,性能特点分析的文章不少,但是介绍它的具体实现步骤的文章很少。这里有一篇:[url=http://www.drmaciver.com/2010/01/understanding-timsort-1adaptive-mergesort/]Understanding timsort, Part 1: Adaptive Mergesort[/url],用 C 作为示例代码。

基于这个文章的介绍,我用 python 实现一遍 TimSort,并说一下其中的关键步骤。因为原文只讲解了 TimSort 中最基本最重要的部分,所以本文也没有超过这个范围。本文不是对 TimSort 的分析,只是介绍一下其基本实现。

[size=large]TimSort 概览[/size]
TimSort 是一个归并排序做了大量优化的版本。对归并排序排在已经反向排好序的输入时表现O(n^2)的特点做了特别优化。对已经正向排好序的输入减少回溯。对两种情况混合(一会升序,一会降序)的输入处理比较好。

[size=large]TimSort 核心过程[/size]
假定,我们的 TimSort 是进行升序排序。TimSort 为了减少对升序部分的回溯和对降序部分的性能倒退,将输入按其升序和降序特点进行了分区。排序的输入的单位不是一个个单独的数字了,而一个个的分区。其中每一个分区我们叫一个“run“。针对这个 run 序列,每次我们拿一个 run 出来进行归并。每次归并会将两个 runs 合并成一个 run。归并的结果保存到 "run_stack" 上。如果我们觉得有必要归并了,那么进行归并,直到消耗掉所有的 runs。这时将 run_stack 上剩余的 runs 归并到只剩一个 run 为止。这时这个仅剩的 run 即为我们需要的排好序的结果。


def timsort(arr):
arr = arr or []
if len(arr) <= 0: return []
runs = _partition_to_runs(arr)
run_stack = []
for run in runs:
run_stack.append(run)
while _should_merge(run_stack):
_merge_stack(run_stack)
while len(run_stack) > 1:
_merge_stack(run_stack)
return run_stack[0]


这里“觉得有必要”这句话很模糊,到底什么时候有必要后面会给出定义。

[size=large]如何分区[/size]
为了在已经按升序排好序的输入面前减少回溯,我们把输入当中已经有序的这些段分组,使得它们成为一个基本单元,这样我们就不必在这个基本单元内部浪费时间进行回溯了。比如[1, 2, 3, 2] 进行分区后就变成了 [[1, 2, 3], [2]]。

为了在已经按降序排好序的输入面前避免归并排序倒退成 O(n^2),我们把输入当中降序的部分翻转成升序,也作为一个单元。比如 [3, 2, 1, 3] 进行分区后就变成了 [[1, 2, 3], [3]]。


def _partition_to_runs(arr):
partitioned_up_to = 0
while partitioned_up_to < len(arr):
if not len(arr) - partitioned_up_to:
return
if len(arr) - partitioned_up_to == 1:
part = list(arr[-1:])
partitioned_up_to += 1
yield part
else:
if arr[partitioned_up_to] > arr[partitioned_up_to + 1]: # 这里必须是严格降序
next_pos = _find_desc_boundary(arr, partitioned_up_to)
_reverse(arr, partitioned_up_to, next_pos)
else:
next_pos = _find_asc_boundary(arr, partitioned_up_to)

part = arr[partitioned_up_to:next_pos]
partitioned_up_to = next_pos
yield part

def _find_desc_boundary(arr, start):
if start >= len(arr) - 1:
return start + 1
if arr[start] > arr[start+1]: # 这里必须是严格降序
return _find_desc_boundary(arr, start + 1)
else:
return start + 1

def _reverse(arr, start=0, end=None):
# 正常的翻转函数,实现省略

def _find_asc_boundary(arr, start):
if start >= len(arr) - 1:
return start + 1
if arr[start] <= arr[start+1]:
return _find_asc_boundary(arr, start + 1)
else:
return start + 1


这里注意降序的部分必须是“严格”降序才能进行翻转。因为 TimSort 的一个重要目标是保持稳定性(stability)。如果在 >= 的情况下进行翻转这个算法就不再是 stable sorting algorithm 了。

[size=large]逆向分解[/size]
传统的归并排序是通过递归,用函数栈把每次 "divide" 的结果保存下来的。divide 的最终结果是一个个的基本单元-单个数字。但是我们看到 TimSort 把这个过程反过来了。我们经过一次分区,已经拿到了了基本单元列表,只不过这次基本单元是一串数字。所以我们只能自己手工将将基本单元列表进行合并。

[size=large]如何合并[/size]
那么何时进行合并?合并的策略是要在 "run_stack" 上维护一个不变式。当这个不变式被打破时即进行合并。传统的归并排序通过二分法可以保证函数栈的深度为 log(n)。我们也模拟这个策略,也让 run_stack 的长度不超过 log(n)。假如 runN 先入栈,runN+1 紧随其后入栈。那么就要求 runN 的长度要是 runN+1 长度的 2 倍。所以归并的条件是:如果 runN 的长度 < (runN+1 的长度 * 2) 即进行归并。


# 因为我们每次新添 run 进入 run_stack 时都判断是否需要归并,
# 并且在每次归并之后还要进一步确保 run_stack 是满足不变式的,
# 所以这里只判断栈头的两个 run 就够了。
def _should_merge(run_stack):
if len(run_stack) < 2:
return False
return len(run_stack[-2]) < 2*len(run_stack[-1])

def _merge(ls1, ls2):
# 正常的归并函数,实现省略

def _merge_stack(run_stack):
head = run_stack.pop()
next = run_stack.pop()
new_run = _merge(next, head)
run_stack.append(new_run)


跟分区的情况类似,这里在归并的时候也要用 stable merge。

[size=large]插入排序优化[/size]
到上面的步骤为止,程序已经可以正确地排序了。但是我们知道插入排序在输入元素数小于一个阀值的时候相比其它排序会更快,所以很多排序算法在 divide 这一步进行到只剩不到这个阀值个数的元素的时候会改用插入排序(比如 JDK6 的快排,参考[url=http://www.blogjava.net/killme2008/archive/2010/09/08/quicksort_optimized.html]这里[/url]),所以我们也要做这个优化。

在分区的时候,如果我们观察到新产生出来的 run 的长度小于适用于插入排序的阀值,我们就用插入排序把这个 run 的长度扩充到这个阀值。


def _partition_to_runs(arr):
partitioned_up_to = 0
while partitioned_up_to < len(arr):
if not len(arr) - partitioned_up_to:
return
if len(arr) - partitioned_up_to == 1:
part = list(arr[-1:])
partitioned_up_to += 1
yield part
else:
if arr[partitioned_up_to] > arr[partitioned_up_to + 1]:
next_pos = _find_desc_boundary(arr, partitioned_up_to)
_reverse(arr, partitioned_up_to, next_pos)
else:
next_pos = _find_asc_boundary(arr, partitioned_up_to)

# 只加了这一句话
next_pos = _do_insertion_sort_optimization(arr, partitioned_up_to, next_pos)

part = arr[partitioned_up_to:next_pos]
partitioned_up_to = next_pos
yield part

def _insertion_sort(arr, start, end):
# 标准插入排序实现

def _do_insertion_sort_optimization(arr, start, end):
length = end - start
if length < INSERTION_SORT_THRESHOLD:
end = min(start+INSERTION_SORT_THRESHOLD, len(arr))
_insertion_sort(arr, start, end)
return end


这里我们只加一句话就够了。剩余的就是标准的插入排序实现。

[size=large]与原文代码的差异[/size]
TimSort 最多使用 O(n) 临时内存空间。由于原文是 C 的代码,为了减少 malloc 的次数而一次性分配了 O(n) 的数组空间。我们这里因为是用 python,也这么做会显得很怪异。所以内存是在每次归并的时候一点点分配的。

TimSort 的实现逻辑上可以看成分区和归并两部分。但由于 C 不支持协程,而 python 通过 generator 部分支持协程。所以为了提高可读性,分区的部分我是用 generator 的方式做的。在代码上与归并的部分完全分离。而原文为了达到 lazy 的目的,是一边分区一边归并的。

[size=large]完整的实现和测试代码[/size]

# -*- coding: utf-8 -*-
import functools
from unittest import TestCase

INSERTION_SORT_THRESHOLD = 6

def _find_desc_boundary(arr, start):
if start >= len(arr) - 1:
return start + 1
if arr[start] > arr[start+1]:
return _find_desc_boundary(arr, start + 1)
else:
return start + 1

def _reverse(arr, start=0, end=None):
if end is None:
end = len(arr)
for i in range(start, start + (end-start)//2):
opposite = end - i - 1
arr[i], arr[opposite] = arr[opposite], arr[i]

def _find_asc_boundary(arr, start):
if start >= len(arr) - 1:
return start + 1
if arr[start] <= arr[start+1]:
return _find_asc_boundary(arr, start + 1)
else:
return start + 1

def _insertion_sort(arr, start, end):
if end - start <= 1:
return
for i in range(start, end):
v = arr[i]
j = i - 1
while j>=0 and arr[j] > v:
arr[j+1] = arr[j]
j -= 1
arr[j+1] = v

def _do_insertion_sort_optimization(arr, start, end):
length = end - start
if length < INSERTION_SORT_THRESHOLD:
end = min(start+INSERTION_SORT_THRESHOLD, len(arr))
_insertion_sort(arr, start, end)
return end

def _partition_to_runs(arr):
partitioned_up_to = 0
while partitioned_up_to < len(arr):
if not len(arr) - partitioned_up_to:
return
if len(arr) - partitioned_up_to == 1:
part = list(arr[-1:])
partitioned_up_to += 1
yield part
else:
if arr[partitioned_up_to] > arr[partitioned_up_to + 1]:
next_pos = _find_desc_boundary(arr, partitioned_up_to)
_reverse(arr, partitioned_up_to, next_pos)
else:
next_pos = _find_asc_boundary(arr, partitioned_up_to)

next_pos = _do_insertion_sort_optimization(arr, partitioned_up_to, next_pos)

part = arr[partitioned_up_to:next_pos]
partitioned_up_to = next_pos
yield part

def _should_merge(run_stack):
if len(run_stack) < 2:
return False
return len(run_stack[-2]) < 2*len(run_stack[-1])

def _merge(ls1, ls2, merge_storage=None):
ret = merge_storage or []
i1 = 0
i2 = 0
while i1 < len(ls1) and i2 < len(ls2):
a = ls1[i1]
b = ls2[i2]
if a <= b:
ret.append(a)
i1 += 1
else:
ret.append(b)
i2 += 1
ret += ls1[i1:]
ret += ls2[i2:]
return ret

def _merge_stack(run_stack, merge_storage=None):
head = run_stack.pop()
next = run_stack.pop()
new_run = _merge(next, head, merge_storage=merge_storage)
run_stack.append(new_run)

def timsort(arr):
arr = arr or []
if len(arr) <= 0: return []
runs = _partition_to_runs(arr)
run_stack = []
for run in runs:
run_stack.append(run)
while _should_merge(run_stack):
_merge_stack(run_stack)
while len(run_stack) > 1:
_merge_stack(run_stack)
return run_stack[0]

class Test(TestCase):
class Elem:
seq_no = 0
def __init__(self, n):
Elem = Test.Elem
self.n = n
self.seq_no = Elem.seq_no
Elem.seq_no += 1

def __lt__(self, other):
return self.n < other.n

def __str__(self):
return "E" + str(self.n) + "S" + str(self.seq_no)
Elem = functools.total_ordering(Elem)

def setUp(self):
Test.Elem.seq_no = 0

def test_reverse(self):
arr = [3, 2, 1, 4, 7, 5, 6]
_reverse(arr)
self.assertEquals(arr, [6, 5, 7, 4, 1, 2, 3])

arr = [3, 2, 1]
_reverse(arr)
self.assertEquals(arr, [1, 2, 3])

def test_find_asc_boundary(self):
arr = [1, 2, 3, 3, 2]
self.assertEqual(_find_asc_boundary(arr, 0), 4)

arr = [1, 2, 3, 3]
self.assertEqual(_find_asc_boundary(arr, 0), 4)

def test_find_desc_boundary(self):
arr = [3, 2, 1]
self.assertEqual(_find_desc_boundary(arr, 0), 3)

arr = [3, 2, 1, 1]
self.assertEqual(_find_desc_boundary(arr, 0), 3)

def test_merge_stack(self):
arr1 = [1, 2, 3]
arr2 = [2, 3, 4]
stack = [arr1, arr2]
_merge_stack(stack)
self.assertEqual(stack, [[1, 2, 2, 3, 3, 4]])

def test_merge_stability(self):
Elem = Test.Elem
arr1 = map(lambda e: Elem(e), [1, 2, 3])
arr2 = map(lambda e: Elem(e), [2, 3, 4])
stack = [arr1, arr2]
_merge_stack(stack)
self.assertEqual(map(lambda lst: map(str, lst), stack), [['E1S0', 'E2S1', 'E2S3', 'E3S2', 'E3S4', 'E4S5']])

def test_timsort(self):
Elem = Test.Elem
arr = map(lambda e: Elem(e), [3, 1, 2, 2, 7, 5])
ret = timsort(arr)
self.assertEquals(map(str, ret), ['E1S1', 'E2S2', 'E2S3', 'E3S0', 'E5S5', 'E7S4'])

self.assertEqual(timsort([]), [])
self.assertEqual(timsort(None), [])

你可能感兴趣的:(算法,动态语言)