字典树通常用来处理字符串前缀查找问题,不过很多题都可以用cpp_map水过去
在n个字符串words[i]里,执行m次查找
例题: 648. 单词替换
题意:给定词根字典和一个句子,把句子中的每个单词替换成它在字典中的最短前缀(词根);找不到词根的单词保持不变。
今天的每日一题,看别人代码发现有nb的实现方案,速度又快还节省空间,还代码短。
class Solution:
    """LeetCode 648: replace every word of a sentence by its shortest root."""

    def replaceWords(self, dictionary: List[str], sentence: str) -> str:
        """Return ``sentence`` with each word swapped for its shortest root.

        A nested-dict trie is built from ``dictionary``.  Every node that ends
        a root stores that root under the key ``'is_end'``; the key is longer
        than one character, so it cannot clash with a single-character edge.
        Words without any root in the trie are kept unchanged.
        """
        root = {}
        for entry in dictionary:
            node = root
            for ch in entry:
                node = node.setdefault(ch, {})
            node['is_end'] = entry

        def shortest_root(word):
            # Walk down the trie and stop at the first node that ends a root.
            node = root
            for ch in word:
                node = node.get(ch)
                if node is None:
                    return word
                if 'is_end' in node:
                    return node['is_end']
            return word

        return ' '.join(shortest_root(word) for word in sentence.split())
例题: 6183. 字符串的前缀分数和
311周赛T4.只能说很水,但之前没有写记cnt的模板,因此比赛时候调了一会。
这题把每个字符串的所有非空前缀在整个数组中(作为前缀)出现的次数之和作为分数,因此find的时候沿路径累加每个节点的cnt即可。
[LeetCode周赛复盘] 第 311 场周赛20220918
class Solution:
    """LeetCode 2416 (weekly contest 311, T4): sum of prefix scores."""

    def sumPrefixScores(self, words: List[str]) -> List[int]:
        """Return, for every word, the summed pass-through counts of its prefixes.

        Each trie node keeps ``'cnt'``: how many inserted words walk through
        it.  The score of a word is the sum of ``'cnt'`` along its own path.
        """
        root = {}
        for word in words:
            node = root
            for ch in word:
                child = node.get(ch)
                if child is None:
                    child = node[ch] = {'cnt': 0}
                child['cnt'] += 1
                node = child
            # Mark the terminal node with the complete word.
            node['end'] = word

        def score(word):
            total = 0
            node = root
            for ch in word:
                node = node.get(ch)
                if node is None:
                    break
                total += node['cnt']
            return total

        return [score(word) for word in words]
例题: 211. 添加与搜索单词 - 数据结构设计
这题由于点代表任意字符,所以要dfs,写成了树
class TrieNode:
    """Trie node with 26 child slots, a pass-through counter and an end flag."""

    def __init__(self, cnt=0):
        # Flag set on the node where an inserted word terminates.
        self.is_end = False
        # One slot per lowercase letter; None means "no child".
        self.next = [None for _ in range(26)]
        # Number of inserted words that pass through this node.
        self.cnt = cnt
class WordDictionary:
    """LeetCode 211: word store whose search pattern may contain '.' wildcards."""

    def __init__(self):
        self.root = TrieNode()
        self.depth = 0

    def addWord(self, word: str) -> None:
        """Insert ``word``, bumping ``cnt`` on every node along its path."""
        node = self.root
        for ch in word:
            slot = ord(ch) - ord('a')
            child = node.next[slot]
            if not child:  # no edge for this letter yet
                child = node.next[slot] = TrieNode()
            child.cnt += 1
            node = child
        node.is_end = True

    def search(self, word: str) -> bool:
        """Return True if a stored word matches ``word``; '.' matches any letter.

        Depth-first search with an explicit stack of (node, position) pairs.
        """
        pending = [(self.root, 0)]
        length = len(word)
        while pending:
            node, pos = pending.pop()
            if pos == length:
                if node.is_end:
                    return True
                continue
            ch = word[pos]
            if ch == '.':
                # The wildcard fans out to every live child.
                for child in node.next:
                    if child and child.cnt > 0:
                        pending.append((child, pos + 1))
            else:
                child = node.next[ord(ch) - ord('a')]
                if child and child.cnt > 0:
                    pending.append((child, pos + 1))
        return False
例题: 1268. 搜索推荐系统
字典树
树实现经典案例
class TrieNode:
    """Pointer-based trie over 'a'..'z' that can list words for a prefix."""

    def __init__(self, cnt=0):
        self.cnt = cnt
        self.next = [None] * 26
        self.is_end = False

    def insert(self, word: str) -> None:
        """Insert ``word`` below this node, counting every node on the path."""
        node = self
        for ch in word:
            slot = ord(ch) - ord('a')
            child = node.next[slot]
            if not child:  # no edge for this letter yet
                child = node.next[slot] = TrieNode()
            child.cnt += 1
            node = child
        node.is_end = True

    def find(self, word):
        """Return up to three stored words starting with ``word``.

        Words come out in lexicographic order because children are visited
        from 'a' to 'z' in a pre-order traversal.  An unknown prefix gives
        an empty list.
        """
        node = self
        consumed = []
        for ch in word:
            consumed.append(ch)
            node = node.next[ord(ch) - ord('a')]
            if not node:
                return []

        found = []
        # Explicit stack; children are pushed in reverse so 'a' is popped first.
        stack = [(node, ''.join(consumed))]
        while stack and len(found) < 3:
            node, text = stack.pop()
            if node.is_end:
                found.append(text)
            for slot in range(25, -1, -1):
                child = node.next[slot]
                if child:
                    stack.append((child, text + chr(ord('a') + slot)))
        return found
class Solution:
    """LeetCode 1268: search suggestions, pointer-based trie version."""

    def suggestedProducts(self, products: List[str], searchWord: str) -> List[List[str]]:
        """Return up to three suggestions for every non-empty prefix of ``searchWord``."""
        root = TrieNode()
        for name in products:
            root.insert(name)
        return [
            root.find(searchWord[:end])
            for end in range(1, len(searchWord) + 1)
        ]
字典树
数组实现-动态开数组:提交后竟然没有树快
class Trie:
    """Array-backed trie whose node table grows on demand.

    Node 0 is the root.  ``trie[u][i]`` is the index of the child of node
    ``u`` for the character ``chr(ord('a') + i)``; 0 means "no child".
    """

    def __init__(self, c_cnt=26):
        self.c_cnt = c_cnt  # alphabet size, i.e. child slots per node
        self.trie = [[0] * c_cnt]
        self.is_end = [False]
        self.w = [0]  # how many inserted words pass through each node

    def insert(self, word):
        """Insert ``word``, creating nodes as needed and counting passes."""
        trie = self.trie
        u = 0
        for c in word:
            i = ord(c) - ord('a')
            if trie[u][i] == 0:
                trie.append([0] * self.c_cnt)
                trie[u][i] = len(trie) - 1
                self.is_end.append(False)
                self.w.append(0)
            u = trie[u][i]
            self.w[u] += 1
        self.is_end[u] = True

    def find(self, word):
        """Return how many inserted words start with ``word`` (0 if none)."""
        trie = self.trie
        u = 0
        for c in word:
            i = ord(c) - ord('a')
            if trie[u][i] == 0:
                return 0
            u = trie[u][i]
        return self.w[u]

    def find_ans(self, word):
        """Return up to three stored words with prefix ``word``, in sorted order."""
        trie = self.trie
        ans = []
        u = 0
        now_word = ''
        for c in word:
            now_word += c
            i = ord(c) - ord('a')
            if trie[u][i] == 0:
                return ans
            u = trie[u][i]

        def dfs(u, now_word):
            if len(ans) >= 3:
                return
            if self.is_end[u]:
                ans.append(now_word)
            # Iterate over the configured alphabet, not a hard-coded 26, so
            # the traversal agrees with the row width used by insert().
            for i in range(self.c_cnt):
                if len(ans) >= 3:
                    return
                if trie[u][i] > 0:
                    dfs(trie[u][i], now_word + chr(ord('a') + i))

        dfs(u, now_word)
        return ans
class Solution:
    """LeetCode 1268: search suggestions, dynamically grown array trie."""

    def suggestedProducts(self, products: List[str], searchWord: str) -> List[List[str]]:
        """Return up to three suggestions for every non-empty prefix of ``searchWord``."""
        index = Trie()
        for name in products:
            index.insert(name)
        return [
            index.find_ans(searchWord[:end])
            for end in range(1, len(searchWord) + 1)
        ]
字典树
数组实现-预先开数组:数组开辟长度应该为 len*m+1(len是单词数,m是最长单词长度,+1 是根节点)。提交后竟然没有动态开数组快。
class Trie:
    """Array-backed trie with a pre-allocated node table.

    Node 0 is the root.  ``trie[u][i]`` is the index of the child of node
    ``u`` for the character ``chr(ord('a') + i)``; 0 means "no child".
    ``size`` is the number of nodes reserved up front.  If it turns out to be
    too small, the tables grow on demand instead of raising IndexError.
    """

    def __init__(self, size, c_cnt=26):
        size = max(size, 1)  # the root node must always exist
        self.c_cnt = c_cnt  # alphabet size, i.e. child slots per node
        self.trie = [[0] * c_cnt for _ in range(size)]
        self.is_end = [False] * size
        self.w = [0] * size  # how many inserted words pass through each node
        self.total_cnt = 1  # index of the next unused node (0 is the root)

    def insert(self, word):
        """Insert ``word``, taking fresh nodes from the pre-allocated pool."""
        trie = self.trie
        u = 0
        for c in word:
            i = ord(c) - ord('a')
            if trie[u][i] == 0:
                if self.total_cnt == len(trie):
                    # Pool exhausted: extend all three tables by one node.
                    trie.append([0] * self.c_cnt)
                    self.is_end.append(False)
                    self.w.append(0)
                trie[u][i] = self.total_cnt
                self.total_cnt += 1
            u = trie[u][i]
            self.w[u] += 1
        self.is_end[u] = True

    def find(self, word):
        """Return how many inserted words start with ``word`` (0 if none)."""
        trie = self.trie
        u = 0
        for c in word:
            i = ord(c) - ord('a')
            if trie[u][i] == 0:
                return 0
            u = trie[u][i]
        return self.w[u]

    def find_ans(self, word):
        """Return up to three stored words with prefix ``word``, in sorted order."""
        trie = self.trie
        ans = []
        u = 0
        now_word = ''
        for c in word:
            now_word += c
            i = ord(c) - ord('a')
            if trie[u][i] == 0:
                return ans
            u = trie[u][i]

        def dfs(u, now_word):
            if len(ans) >= 3:
                return
            if self.is_end[u]:
                ans.append(now_word)
            # Iterate over the configured alphabet, not a hard-coded 26, so
            # the traversal agrees with the row width used by insert().
            for i in range(self.c_cnt):
                if len(ans) >= 3:
                    return
                if trie[u][i] > 0:
                    dfs(trie[u][i], now_word + chr(ord('a') + i))

        dfs(u, now_word)
        return ans
class Solution:
    """LeetCode 1268: search suggestions, pre-allocated array trie."""

    def suggestedProducts(self, products: List[str], searchWord: str) -> List[List[str]]:
        """Return up to three suggestions for every non-empty prefix of ``searchWord``.

        The trie is sized for the worst case of one node per character of
        every product plus the root.  An empty ``products`` list yields an
        empty suggestion list for every prefix instead of raising.
        """
        # default=0 keeps max() from raising ValueError on an empty list.
        max_len = max(map(len, products), default=0)
        trie = Trie(max_len * len(products) + 1)
        for product in products:
            trie.insert(product)
        return [
            trie.find_ans(searchWord[:end])
            for end in range(1, len(searchWord) + 1)
        ]
链接: 421. 数组中两个数的最大异或值
01字典树
经典案例,贪心的去找子节点
class TrieXor:
    """Binary (0/1) trie of nested dicts for maximum-XOR queries.

    Every number is stored as the path of its bits from bit 31 down to bit 0.
    """

    def __init__(self, nums=None):
        # ``nums`` is optional; when given, every value is inserted at once.
        self.tree = {}
        if nums:
            for value in nums:
                self.insert(value)

    def insert(self, num):
        """Add ``num`` to the trie, most significant bit first."""
        node = self.tree
        for shift in reversed(range(32)):
            node = node.setdefault((num >> shift) & 1, {})

    def find_max_xor_num(self, num):
        """Return the largest ``num ^ v`` over the stored values ``v``.

        Greedy from the top bit: prefer the child whose bit differs from the
        bit of ``num``; otherwise follow the remaining child and leave this
        result bit at 0.  An empty trie gives 0.
        """
        node = self.tree
        best = 0
        for shift in reversed(range(32)):
            wanted = 1 - ((num >> shift) & 1)
            best <<= 1
            if wanted in node:
                node = node[wanted]
                best |= 1
            else:
                node = node.get(1 - wanted, {})
        return best
class Solution:
    """LeetCode 421: maximum XOR of two numbers in an array."""

    def findMaximumXOR(self, nums: List[int]) -> int:
        """Return the best XOR obtainable by pairing any two entries of ``nums``."""
        bits = TrieXor(nums)
        return max(map(bits.find_max_xor_num, nums))
链接: 1803. 统计异或值在范围内的数对有多少
01字典树
# Idea (the original note is truncated): let f(x, limit) be the number of
# stored values v with (v XOR x) < limit.  The countPairs solution that
# follows this class sums f(x, high + 1) - f(x, low) over the input.
class TrieXor:
    """Binary (0/1) trie of nested dicts with per-node pass counters.

    Numbers are stored as bit paths from bit ``bit_len`` down to bit 0.
    Besides the child keys 0 and 1, a node may hold two bookkeeping keys:
    ``3`` - how many inserted numbers pass through the node, and ``5`` - the
    number itself, stored on the leaf only.
    """

    def __init__(self, nums=None, bit_len=31):
        # bit_len has to be set BEFORE the initial inserts: insert() reads it,
        # so assigning it afterwards made TrieXor(nums) raise AttributeError.
        self.bit_len = bit_len
        self.trie = {}
        self.cnt = 0  # how many values have been inserted
        if nums:
            for a in nums:
                self.insert(a)

    def insert(self, num):
        """Insert ``num`` and bump the pass counter of every node on its path."""
        cur = self.trie
        for i in range(self.bit_len, -1, -1):
            nxt = (num >> i) & 1
            if nxt not in cur:
                cur[nxt] = {}
            cur = cur[nxt]
            cur[3] = cur.get(3, 0) + 1  # pass counter of this node
        # Keys 0 and 1 are the only child edges, so 5 is free to hold the value.
        cur[5] = num
        self.cnt += 1

    def find_max_xor_num(self, num):
        """Return the largest ``num ^ v`` over the stored values (0 if empty).

        Greedy from the top bit: prefer the child whose bit differs from the
        bit of ``num``; fall back to the other child otherwise.
        """
        cur = self.trie
        ret = 0
        for i in range(self.bit_len, -1, -1):
            want = 1 - ((num >> i) & 1)
            ret <<= 1
            if want in cur:
                cur = cur[want]
                ret |= 1
            else:
                cur = cur.get(1 - want, {})
        return ret

    def count_less_than_limit_xor_num(self, num, limit):
        """Count the stored values ``v`` with ``v ^ num < limit``.

        Walking from the top bit, with ``a`` the bit of ``num`` and ``b`` the
        bit of ``limit``:

        * ``b == 1``: the child ``a`` gives XOR bit 0, so its whole subtree is
          strictly smaller and is counted; the walk goes on in child ``a ^ 1``
          (XOR bit 1, still tied with ``limit``).
        * ``b == 0``: only child ``a`` (XOR bit 0) can stay tied.

        The walk stops as soon as the required child does not exist.
        """
        cur = self.trie
        ans = 0
        for i in range(self.bit_len, -1, -1):
            a, b = (num >> i) & 1, (limit >> i) & 1
            if b == 1:
                if a in cur:
                    ans += cur[a][3]
                cur = cur.get(a ^ 1)
            else:
                cur = cur.get(a)
            if not cur:
                break
        return ans
class Solution:
    """LeetCode 1803: count pairs whose XOR lies in ``[low, high]``."""

    def countPairs(self, nums: List[int], low: int, high: int) -> int:
        """Return the number of index pairs ``i < j`` with ``low <= nums[i] ^ nums[j] <= high``.

        Every value is matched against the values seen before it:
        (count with XOR < high + 1) minus (count with XOR < low).
        """
        seen = TrieXor(bit_len=15)
        pairs = 0
        for value in nums:
            up_to_high = seen.count_less_than_limit_xor_num(value, high + 1)
            below_low = seen.count_less_than_limit_xor_num(value, low)
            pairs += up_to_high - below_low
            seen.insert(value)
        return pairs
class BinaryTrie:
    """
    Array-backed binary (0/1) trie that behaves like a sorted multiset of ints.

    Values are stored as ``max_log``-bit paths, most significant bit first.
    Node 0 is the root; ``edges[2 * v + d]`` is the child of node ``v`` for
    bit ``d`` (-1 when absent).  ``lazy`` is a pending XOR mask: values are
    stored XOR-ed with it and read back through it, which makes
    :meth:`xor_all` O(1).

    Reference:
    - https://atcoder.jp/contests/arc028/submissions/19916627
    - https://judge.yosupo.jp/submission/35057
    """

    __slots__ = (
        "max_log",
        "x_end",
        "v_list",
        "multiset",
        "add_query_count",
        "add_query_limit",
        "edges",
        "size",
        "is_end",
        "max_v",
        "lazy",
    )

    def __init__(
        self,
        max_log: int = 60,  # depth of the trie; depends on the value range
        allow_multiple_elements: bool = True,  # keep duplicates or not
        add_query_limit: int = 10**6,  # how many add() calls to reserve room for
    ):
        self.max_log = max_log
        self.x_end = 1 << max_log  # exclusive upper bound of storable values
        # Scratch list with the nodes of the last walked path.  The extra
        # final slot stays 0, i.e. it always refers to the root.
        self.v_list = [0] * (max_log + 1)
        self.multiset = allow_multiple_elements
        self.add_query_count = 0  # number of add() calls so far
        self.add_query_limit = add_query_limit  # only used to size the arrays
        # Every add() creates at most max_log nodes; +1 for the root.
        n = max_log * add_query_limit + 1
        self.edges = [-1] * (2 * n)
        self.size = [0] * n  # number of stored values below each node
        self.is_end = [0] * n  # multiplicity of the value ending at each leaf
        self.max_v = 0  # index of the last allocated node
        self.lazy = 0  # XOR mask applied to the whole trie

    def add(self, x: int):
        """Insert ``x``.

        The range ``0 <= x < 2**max_log`` and the ``add_query_limit`` budget
        are not checked here.  When duplicates are disabled and ``x`` is
        already present, nothing is stored.
        """
        x ^= self.lazy
        v = 0
        for i in range(self.max_log - 1, -1, -1):
            d = (x >> i) & 1
            if self.edges[2 * v + d] == -1:
                self.max_v += 1
                self.edges[2 * v + d] = self.max_v
            v = self.edges[2 * v + d]
            self.v_list[i] = v
        if self.multiset or self.is_end[v] == 0:
            self.is_end[v] += 1
            for v in self.v_list:
                self.size[v] += 1
        self.add_query_count += 1

    def discard(self, x: int):
        """Remove one occurrence of ``x``; do nothing if it is absent."""
        if not 0 <= x < self.x_end:
            return
        x ^= self.lazy
        v = 0
        for i in range(self.max_log - 1, -1, -1):
            d = (x >> i) & 1
            if self.edges[2 * v + d] == -1:
                return
            v = self.edges[2 * v + d]
            self.v_list[i] = v
        if self.is_end[v] > 0:
            self.is_end[v] -= 1
            for v in self.v_list:
                self.size[v] -= 1

    def erase(self, x: int, count: int = -1):
        """Remove ``count`` occurrences of ``x``; ``count == -1`` removes all.

        A ``count`` larger than the stored multiplicity is clamped to it.
        """
        if not 0 <= x < self.x_end:
            return
        x ^= self.lazy
        v = 0
        for i in range(self.max_log - 1, -1, -1):
            d = (x >> i) & 1
            if self.edges[2 * v + d] == -1:
                return
            v = self.edges[2 * v + d]
            self.v_list[i] = v
        if count == -1 or self.is_end[v] < count:
            count = self.is_end[v]
        if self.is_end[v] > 0:
            self.is_end[v] -= count
            for v in self.v_list:
                self.size[v] -= count

    def count(self, x: int) -> int:
        """Return how many copies of ``x`` are stored."""
        if not 0 <= x < self.x_end:
            return 0
        x ^= self.lazy
        v = 0
        for i in range(self.max_log - 1, -1, -1):
            d = (x >> i) & 1
            if self.edges[2 * v + d] == -1:
                return 0
            v = self.edges[2 * v + d]
        return self.is_end[v]

    def bisect_left(self, x: int) -> int:
        """Return the number of stored values strictly smaller than ``x``.

        This equals the index of the first value ``>= x`` in sorted order.
        """
        if x < 0:
            return 0
        if self.x_end <= x:
            return len(self)
        v = 0
        ret = 0
        for i in range(self.max_log - 1, -1, -1):
            d = (x >> i) & 1
            l = (self.lazy >> i) & 1
            lc = self.edges[2 * v]
            rc = self.edges[2 * v + 1]
            if l == 1:
                # A set lazy bit swaps the roles of the two children.
                lc, rc = rc, lc
            if d:
                # Everything in the logical 0-child is smaller than x.
                if lc != -1:
                    ret += self.size[lc]
                if rc == -1:
                    return ret
                v = rc
            else:
                if lc == -1:
                    return ret
                v = lc
        return ret

    def bisect_right(self, x: int) -> int:
        """Return the index of the first stored value greater than ``x``."""
        return self.bisect_left(x + 1)

    def index(self, x: int) -> int:
        """Return the sorted index of ``x``; raise ValueError if it is absent."""
        if x not in self:
            raise ValueError(f"{x} is not in BinaryTrie")
        return self.bisect_left(x)

    def find(self, x: int) -> int:
        """Return the sorted index of ``x``, or -1 if it is absent."""
        if x not in self:
            return -1
        return self.bisect_left(x)

    def kth_elem(self, k: int) -> int:
        """Return the ``k``-th smallest value (0-indexed; negative ``k`` counts from the end).

        Raises:
            IndexError: if ``k`` is out of range, including any ``k`` on an
                empty trie.  Without this check the walk would step into the
                missing node -1 and return a meaningless number.
        """
        if k < 0:
            k += self.size[0]
        if not 0 <= k < self.size[0]:
            raise IndexError("BinaryTrie index out of range")
        v = 0
        ret = 0
        for i in range(self.max_log - 1, -1, -1):
            l = (self.lazy >> i) & 1
            lc = self.edges[2 * v]
            rc = self.edges[2 * v + 1]
            if l == 1:
                lc, rc = rc, lc
            if lc == -1:
                v = rc
                ret |= 1 << i
                continue
            if self.size[lc] <= k:
                # Skip the whole logical 0-child and continue in the 1-child.
                k -= self.size[lc]
                v = rc
                ret |= 1 << i
            else:
                v = lc
        return ret

    def minimum(self) -> int:
        """Return the smallest stored value (IndexError if empty)."""
        return self.kth_elem(0)

    def maximum(self) -> int:
        """Return the largest stored value (IndexError if empty)."""
        return self.kth_elem(-1)

    def xor_all(self, x: int):
        """XOR every stored value with ``x`` by updating the lazy mask only."""
        self.lazy ^= x

    def __iter__(self):
        """Yield all stored values in ascending order, duplicates included."""
        q = [(0, 0)]
        # Level-by-level walk; each entry is (node, value prefix so far).
        for i in range(self.max_log - 1, -1, -1):
            l = (self.lazy >> i) & 1
            nq = []
            for v, x in q:
                lc = self.edges[2 * v]
                rc = self.edges[2 * v + 1]
                if l == 1:
                    lc, rc = rc, lc
                if lc != -1:
                    nq.append((lc, 2 * x))
                if rc != -1:
                    nq.append((rc, 2 * x + 1))
            q = nq
        for v, x in q:
            for _ in range(self.is_end[v]):
                yield x

    def __str__(self):
        prefix = "BinaryTrie("
        content = list(map(str, self))
        suffix = ")"
        if content:
            content[0] = prefix + content[0]
            content[-1] = content[-1] + suffix
        else:
            content = [prefix + suffix]
        return ", ".join(content)

    def __getitem__(self, k):
        return self.kth_elem(k)

    def __contains__(self, x: int) -> bool:
        return bool(self.count(x))

    def __len__(self):
        return self.size[0]

    def __bool__(self):
        return bool(len(self))

    def __ixor__(self, x: int):
        self.xor_all(x)
        return self
class Solution:
    """LeetCode 1803: count pairs whose XOR lies in ``[low, high]`` (BinaryTrie version)."""

    def countPairs(self, nums: List[int], low: int, high: int) -> int:
        """Return the number of index pairs ``i < j`` with ``low <= nums[i] ^ nums[j] <= high``.

        For each value the trie of earlier values is XOR-ed with it, the
        values falling in ``[low, high]`` are counted by two bisections, and
        the XOR is undone before the value itself is added.
        """
        depth = max(nums).bit_length()
        store = BinaryTrie(
            max_log=depth,
            allow_multiple_elements=True,
            add_query_limit=len(nums),
        )
        total = 0
        for value in nums:
            store.xor_all(value)
            upper = store.bisect_right(high)
            lower = store.bisect_left(low)
            total += upper - lower
            store.xor_all(value)  # restore the original values
            store.add(value)
        return total
链接: 4869. 异或值
# Problem: 异或值
# Contest: AcWing
# URL: https://www.acwing.com/problem/content/4872/
# Memory Limit: 256 MB
# Time Limit: 1000 ms
import sys
def RI():
    """Read one line from binary stdin and return a lazy map of its ints."""
    return map(int, sys.stdin.buffer.readline().split())


def RS():
    """Read one line from binary stdin and return a lazy map of decoded tokens."""
    return map(bytes.decode, sys.stdin.buffer.readline().strip().split())


def RILST():
    """Read one line from binary stdin and return its ints as a list."""
    return list(RI())


def DEBUG(*x):
    """Write the argument tuple to stderr and return what ``write`` returns."""
    return sys.stderr.write(f'{str(x)}\n')


MOD = 10 ** 9 + 7
PROBLEM = """
"""
class TrieXor:
    """Binary (0/1) trie of nested dicts with per-node pass counters.

    Numbers are stored as bit paths from bit ``bit_len - 1`` down to bit 0.
    Besides the child keys 0 and 1, a node may hold two bookkeeping keys:
    ``3`` - how many inserted numbers pass through the node, and ``5`` - the
    number itself, stored on the leaf only.
    """

    def __init__(self, nums=None, bit_len=31):
        # bit_len has to be set BEFORE the initial inserts: insert() reads it,
        # so assigning it afterwards made TrieXor(nums) raise AttributeError.
        self.bit_len = bit_len
        self.trie = {}
        self.cnt = 0  # how many values have been inserted
        if nums:
            for a in nums:
                self.insert(a)

    def insert(self, num):
        """Insert the lowest ``bit_len`` bits of ``num`` and count the path."""
        cur = self.trie
        for i in range(self.bit_len - 1, -1, -1):
            nxt = (num >> i) & 1
            if nxt not in cur:
                cur[nxt] = {}
            cur = cur[nxt]
            cur[3] = cur.get(3, 0) + 1  # pass counter of this node
        # Keys 0 and 1 are the only child edges, so 5 is free to hold the value.
        cur[5] = num
        self.cnt += 1

    def find_max_xor_num(self, num):
        """Return the largest ``num ^ v`` over the stored values (0 if empty).

        Greedy from the top bit: prefer the child whose bit differs from the
        bit of ``num``; fall back to the other child otherwise.
        """
        cur = self.trie
        ret = 0
        for i in range(self.bit_len - 1, -1, -1):
            want = 1 - ((num >> i) & 1)
            ret <<= 1
            if want in cur:
                cur = cur[want]
                ret |= 1
            else:
                cur = cur.get(1 - want, {})
        return ret

    def find_max_xor_any(self):
        """Return the minimum over ``x`` of ``max(v ^ x)`` across the stored values.

        An empty trie gives 0 (previously this raised KeyError).
        """
        if not self.trie:
            return 0

        def dfs(cur, bit):  # smallest achievable maximum below this level
            if bit < 0:
                return 0
            if 0 not in cur:
                # All values have a 1 here: pick 1 in x so the result bit is 0.
                return dfs(cur[1], bit - 1)
            if 1 not in cur:
                # All values have a 0 here: pick 0 in x.
                return dfs(cur[0], bit - 1)
            # Both bits occur, so the maximum always has this bit set; take
            # the cheaper of the two subtrees for the remaining bits.
            return min(dfs(cur[0], bit - 1), dfs(cur[1], bit - 1)) + (1 << bit)

        return dfs(self.trie, self.bit_len - 1)

    def count_less_than_limit_xor_num(self, num, limit):
        """Count the stored values ``v`` with ``v ^ num < limit``.

        Walking from the top bit, with ``a`` the bit of ``num`` and ``b`` the
        bit of ``limit``:

        * ``b == 1``: the child ``a`` gives XOR bit 0, so its whole subtree is
          strictly smaller and is counted; the walk goes on in child ``a ^ 1``
          (XOR bit 1, still tied with ``limit``).
        * ``b == 0``: only child ``a`` (XOR bit 0) can stay tied.

        The walk stops as soon as the required child does not exist.
        """
        cur = self.trie
        ans = 0
        for i in range(self.bit_len - 1, -1, -1):
            a, b = (num >> i) & 1, (limit >> i) & 1
            if b == 1:
                if a in cur:
                    ans += cur[a][3]
                cur = cur.get(a ^ 1)
            else:
                cur = cur.get(a)
            if not cur:
                break
        return ans
# 封装成类卡常真是吐了 ms
def solve_tle():
    """Class-based variant: read the input, fill a TrieXor and print the answer."""
    n, = RI()
    values = RILST()
    bits = TrieXor(bit_len=30)
    for value in values:
        bits.insert(value)
    print(bits.find_max_xor_any())
# 7224 ms
def solve1():
    """Inline dict-trie variant: print the minimum over x of max(a_i XOR x)."""
    n, = RI()
    values = RILST()

    # Nested-dict trie over bits 29..0 of every value.
    root = {}
    for value in values:
        node = root
        for shift in range(29, -1, -1):
            node = node.setdefault((value >> shift) & 1, {})

    def best(node, bit):  # smallest achievable maximum below this level
        if bit < 0:
            return 0
        has_zero = 0 in node
        has_one = 1 in node
        if not has_zero:
            # Every value has a 1 here: x takes 1 and the result bit is 0.
            return best(node[1], bit - 1)
        if not has_one:
            # Every value has a 0 here: x takes 0.
            return best(node[0], bit - 1)
        # Both bits occur, so this bit is set in the maximum whatever x is;
        # keep the cheaper subtree for the lower bits.
        return (1 << bit) + min(best(node[0], bit - 1), best(node[1], bit - 1))

    print(best(root, 29))
# 1725 ms
def solve():
    """Trie-free variant: split the values by bit recursively and print the answer."""
    n, = RI()
    values = RILST()

    def best(group, bit):  # smallest achievable maximum below this level
        if bit < 0:
            return 0
        mask = 1 << bit
        ones = [v for v in group if v & mask]
        zeros = [v for v in group if not v & mask]
        if not ones:
            return best(zeros, bit - 1)
        if not zeros:
            return best(ones, bit - 1)
        # Both bits occur, so this bit is set in the maximum whatever x is;
        # keep the cheaper half for the lower bits.
        return mask + min(best(ones, bit - 1), best(zeros, bit - 1))

    print(best(values, 29))
# Script entry point: run solve() only when the file is executed directly.
if __name__ == '__main__':
    solve()