分治法是一种非常通用的算法设计技巧. 在很多实际问题中, 相比直接求解, 分治法往往能显著降低算法的计算复杂度. 常见的可以用分治法求解的问题有: 排序, 矩阵乘法, 整数乘法, 离散傅里叶变换等. 分治法的一般思路如下:
- [Divide] 把原问题拆分成同类型的子问题.
- [Conquer] 用递归的方式求解子问题.
- [Combine] 根据子问题的解构造原问题的解.
分治法最关键的步骤是如何 低成本地 利用子问题的解构来造原问题的解. 它包含两个方面: 1. 可行性, 即, 可以用子问题的解来构造原问题的解; 2. 高效性, 即构造原问题解的时间复杂度较低. 换句话说, 分治法需要比直接求解效率高. 分治法一般是通过递归求解子问题, 其时间复杂度的分析需要用到如下定理.
Master Theorem[1]. 考虑递归式. 当, 时, 我们有
Counting Inversions[2]
在推荐场景中, 一个常用的方法是协同过滤(Collaborative Filtering), 即为相似的用户推荐它们共同喜欢的事物. 以推荐歌曲为例, 我们把用户和对歌曲的偏好分别进行排序, 然后计算有多少首歌在和的排序是"不同的". 最后根据这种不同来定义用户和的相似性, 从而进行歌曲推荐. 具体来说, 对用户对歌曲的偏好按编号. 用户对歌曲的偏好可以表示为
对任意的, 如果, 我们称和称为 反序(inversion). 换句话说, 用户A把歌曲排在歌曲的前面, 但是用户B把歌曲排在了歌曲的前面.
问题描述
给定无重复数字的整数序列, 计算其反序的数量.
算法设计
直接求解的思路是考虑所有的二元组并判断它们是否反序. 这个算法的时间复杂度是. 下面我们用分治法来降低计算复杂度.
注意到这个问题实际上与排序非常类似. 通过对序列进行排序的同时记录不满足"顺序"的二元组数量, 即反序的数量. 令. 把序列分成两部分:
用递归的方式对left和right进行排序, 同时计算left和right中的反序数量. 当序列的长度为1时, 返回0. 下一步是合并子问题的解. 注意到left和right已经是按照从小到大的顺序进行排列. 比较left和right的第一个元素并把较小的元素添加到结果中直到left或right为空, 最后再把剩余的序列添加到结果集. 在比较过程中我们需要记录反序的数量. 当right中的元素小于left中的元素时, 反序的增量为"left中剩余元素的数量". 最终的结果包含三部分之和: left中反序的数量, rihgt中反序的数量和合并时反序的数量.
Python实现
整体的计算过程.
def sort_and_count(x):
if len(x) == 1:
return x, 0
k = len(x) // 2
left, count_left = sort_and_count(x[0: k])
right, count_right = sort_and_count(x[k:])
# 把子问题的解拼接成原问题的解
combined, count = merge_and_count(left, right)
return combined, count + count_left + count_right
归并过程.
def merge_and_count(left, right):
""" 把left和right合并且计算inversion的数量
注意: left和right已经排好序
"""
combined = []
count = 0
while len(left) and len(right):
if left[0] > right[0]: # 反序(左边的编号小于右边的编号是正序)
combined.append(right.pop(0))
count += len(left)
else: # 正序
combined.append(left.pop(0))
return combined + left + right, count
完整代码
计算复杂度
容易分析归并过程的时间复杂度是. 令代表算法的时间复杂度, 我们有
根据Master Theorem, 我们得到.
Closest Pair[2]
Closest Pair是计算几何里的一个基本问题: 给定二维平面上的个点, 找到距离最近的两个点. 通过计算任意两点的距离可以在找到距离最近的两点. 下面利用分治法可以把时间复杂度降低到.
算法设计
如果所有点是一维的, 我们可以把它们排序, 然后计算所有相邻两点的最小距离. 排序耗时, 计算相邻点的最小距离耗时, 因此算法的时间复杂度为. 在二维情形, 我们的思路是类似的:
沿着轴方向对点集进行排序得到.
把按与轴垂的方向均分成两部分和:
-
递归地求解和中的closest pair(如下图所示).
根据和的计算结果构造原问题的解(见下文).
合并(Combine)
设, 分别是和中的closest pair. 如果的closest pair在或中, 我们只需要从和选择距离小的pair作为结果输出. 否则的closest pair其中一点在中, 另一点在中, 这时我们需要比较和中的点. 这样一来, 合并的时间复杂度为! 接下来我们要把合并的时间复杂度降低为.
令, 其中代表, 两点之间的距离. 设 代表和的分割线. 如果存在, 使得, 那么和在轴方向距离一定不超过. 令, 因此. 如下图所示, 中的点在蓝色虚线之间.
[图片上传失败...(image-facf8c-1586865336187)]
把中的点按轴从小到大排序, 得到集合, 其中是一个二元组(代表它在平面中的位置). 我们有如下定理(稍后给出证明):
定理 如果存在满足, 那么.
这样一来, 我们可以在的时间内找到所有距离不超过的点对, 并记录距离最小的点对作为结果输出(如果存在). 思路思路如下:
pairs_within_delta = [] # S中距离不超过delta的点的集合
for s in Sy:
for t in 15 points after s:
if d(s, t) < delta:
add (s,t) to pairs_within_delta
output the minimum distance pair in pairs_within_delta
求解子问题和之前, 首先把根据轴从小到大排序得到, 这样一来可以在时间内构造, 即依次过滤掉中距离超过的点. 在上述算法中, 外层循环次数是, 内层循环是常数, 因此在合并步骤中构造closest pair的时间复杂度最终降低为.
Python实现
先把输入点集分别按轴和轴排序, 得到和. 递归求解的过程参考函数closest_pair_xy
.
def closest_pair(points):
""" 计算二维点集中的closest pair.
:param points: P = [(x1,y1), (x2,y2), ..., (xn, yn)]
:return: 两个距离最近的点
"""
# 把P按x轴和y轴分别进行排序, 得到Px和Py
# 注意: P, Px, Py 三个集合是相同的(仅仅排序不同)
Px = sorted(points, key=lambda item: item[0])
Py = sorted(points, key=lambda item: item[1])
return closest_pair_xy(Px, Py)
def closest_pair_xy(Px, Py):
""" 计算closest pair
:param Px: 把points按x轴升序排列
:param Py: 把points按y轴升序排列
:return: point1, point2
"""
if len(Px) <= 3:
return search_closest_pair(Px)
# 构造子问题的输入: Qx, Rx, Qy, Ry
k = len(Px) // 2
Q, R = Px[0: k], Px[k:]
Qx, Qy = sorted(Q, key=lambda x: x[0]), sorted(Q, key=lambda x: x[1])
Rx, Ry = sorted(R, key=lambda x: x[0]), sorted(R, key=lambda x: x[1])
# 求解子问题
q0, q1 = closest_pair_xy(Qx, Qy)
r0, r1 = closest_pair_xy(Rx, Ry)
# 合并子问题的解
return combine_results_of_sub_problems(Py, Qx, q0, q1, r0, r1)
def search_closest_pair(points):
""" 用枚举的方式寻找closest pair
:param points: [(x1,y1), (x2,y2), ...]
:return: closest pair
"""
n = len(points)
dist_min, i_min, j_min = math.inf, 0, 0
for i in range(n-1):
for j in range(i+1, n):
d = get_dist(points[i], points[j])
if d < dist_min:
dist_min, i_min, j_min = d, i, j
return points[i_min], points[j_min]
def get_dist(a, b):
""" 计算两点之间的欧式距离
"""
return math.sqrt(math.pow(a[0]-b[0], 2) + math.pow(a[1]-b[1], 2))
设代表closest_pair_xy
的计算时间. 根据前文分析, 合并子问题的解并输出原问题的解的时间复杂度为, 因此我们有
根据Master Theorem, 我们有. 此外, 把分别按轴排序的时间复杂度为, 因此算法整体的时间复杂度是.
下面是合并过程的实现.
def combine_results_of_sub_problems(Py, Qx, q0, q1, r0, r1):
"""
:param Py: P按y轴排序的结果
:param Qx: P在x=x0处被切分成Q和R. Qx是Q按x轴排序的结果
:param q0: (q0, q1)是Q中的closest pair
:param q1: 参考q0
:param r0: (r0, r1)是R中的closest pair
:param r1: 参考r0
:return: closest pair in P
"""
# 计算Sy
d = min(get_dist(q0, q1), get_dist(r0, r1))
Sy = get_sy(Py, Qx, d)
# 检查是否存在距离更小的pair
s1, s2 = closest_pair_of_sy(Sy)
if s1 and s2 and get_dist(s1, s2) < d:
return s1, s2
elif get_dist(q0, q1) < get_dist(r0, r1):
return q0, q1
else:
return r0, r1
def get_sy(Py, Qx, d):
""" 根据Py计算Sy.
:param Py: P按y轴排序的结果
:param Qx: P在x=x0处被切分成Q和R. Qx是Q按x轴排序的结果
:param d: delta
:return: S
"""
x0 = Qx[-1][0] # Q最右边点的x坐标值
return [p for p in Py if p[0] - x0 < d]
def closest_pair_of_sy(Sy):
""" 计算集合Sy的closest pair
"""
n = len(Sy)
if n <= 1:
return None, None
dist_min, i_min, j_min = math.inf, 0, 0
for i in range(n-1):
for j in range(i + 1, i + 16):
if j == len(Sy):
break
d = get_dist(Sy[i], Sy[j])
if d < dist_min:
dist_min, i_min, j_min = d, i, j
return Sy[i_min], Sy[j_min]
完整代码
定理证明
定理 如果存在满足, 那么.
根据前文的描述, 已知中的点在下图蓝色虚线之间. 把中的点按轴从小到大排序得到, 其中代表平面中的一个点. 为了方便描述, 我们把下图中蓝色虚线之间用单位长度为的网格划分.
[图片上传失败...(image-dad441-1586865336187)]
假设存在使得. 我们要证明. 证明分为两步:
- 和必须在不同的网格中. 反证法. 假设在同一个网格中(意味着 or ), 它们的距离. 注意是和中的最短距离, 因此矛盾.
- . 反证法. 假设. 如上图所示和之间至少相差3行(网格). 因此, 矛盾.
参考文献
-
T.H. Cormen, C. E. Leiserson, R.L. Rivest and C. Stein. Introduction to Algorithms. Third edition. The MIT Press, 2009. ↩
-
J. Kleinberg and E. Tardos. Algorithm Design. Pearson, 2006. ↩ ↩