倍增是一种优化复杂度的思想,通过把区间压缩到二进制下标的方式,可以大量的合并信息。这也要求区域内的贡献通常是均匀的。
查询时,把路径用二进制分解,那么就可以快速到达目标。
通常,由于初始化需要nlogn的时间,需求应该是离线的。
for i in range(m-1): # 外层优先遍历步数2^i
for u in range(n): # 内层转移节点
p=pa[u][i]; # 当前节点的父节点
pa[u][i + 1] = pa[p][i]; # 那么从u跨两个2^i步就到达p的2^i步
f[u][i+1] = f[u][i]+f[p][i] # 从u跨两个2^i,就是从u夸一次2^i,从p夸一次。
u = s = 0
for j in range(k.bit_length()):
if k>>j&1:
x = pa[u][j]
s += f[u][j]
m = k.bit_length()
,代表把k步压二进制最多压成这么多位。注意转移时只要转移range(m-1)。例题: 1483. 树节点的第 K 个祖先
class TreeAncestor:
def __init__(self, n: int, parent: List[int]):
m = n.bit_length()
self.pa = pa = [[-1]*m for _ in range(n)]
for u,fa in enumerate(parent[1:],start=1):
pa[u][0] = fa
for i in range(m-1):
for u in range(n):
if (p:=pa[u][i]) != -1:
pa[u][i+1] = pa[p][i]
def getKthAncestor(self, u: int, k: int) -> int:
for i in range(k.bit_length()):
if k>>i&1:
u = self.pa[u][i]
if u == -1:
break
return u
def get_lca(self, x: int, y: int) -> int:
"""返回 x 和 y 的最近公共祖先(节点编号从 0 开始)
思路是先让x,y处于同一层,通过kth跳。
然后尝试迈大步(2^i步),若迈完发现变成同节点就不迈了,尝试2^(i-1)步。
最后答案pa[x][0],即x、y一定在lca的直接儿子上,"""
if self.depth[x] > self.depth[y]:
x, y = y, x
# 使 y 和 x 在同一深度
y = self.get_kth_ancestor(y, self.depth[y] - self.depth[x])
if y == x:
return x
for i in range(len(self.pa[x]) - 1, -1, -1):
px, py = self.pa[x][i], self.pa[y][i]
if px != py:
x, y = px, py # 同时上跳 2**i 步
return self.pa[x][0]
链接: 957. N 天后的牢房
class Solution:
def prisonAfterNDays(self, cells: List[int], n: int) -> List[int]:
s = int(''.join(map(str,cells)),2)
m = n.bit_length()
f = [[-1]*m for _ in range(1<<8)]
for i in range(1<<8):
p = 0
for j in range(1,7):
if (i>>(j-1)&1) == (i>>(j+1)&1):
p |= 1<<j
f[i][0] = p
for i in range(m-1):
for j in range(1<<8):
p = f[j][i]
f[j][i+1] = f[p][i]
for j in range(m):
if n>>j&1:
s = f[s][j]
ans = bin(s)[2:]
ans = '0'*(8-len(ans))+ans
return [int(c) for c in ans]
链接: 2836. 在传球游戏中最大化函数值
class Solution:
def getMaxFunctionValue(self, receiver: List[int], k: int) -> int:
n = len(receiver)
m = k.bit_length()
f = [receiver] + [[0]*n for _ in range(m+1)]
pa = [receiver] + [[0]*n for _ in range(m+1)]
for i in range(m-1):
for j in range(n):
p = pa[i][j]
pa[i+1][j] = pa[i][p]
f[i+1][j] = f[i][p] + f[i][j]
ans = 0
for i,v in enumerate(receiver):
s = i
for j in range(k.bit_length()):
if k >> j&1:
s += f[j][i]
i = pa[j][i]
ans = max(ans,s)
return ans