[python刷题模板] 字典树

[python刷题模板] 字典树

    • 一、 算法&数据结构
      • 1. 描述
      • 2. 复杂度分析
      • 3. 常见应用
      • 4. 常用优化
    • 二、 模板代码
      • 0. 添加一个更容易写的字典树-用字典实现每一层。
      • -1. 再添加一个dict实现的字典树,记录每个字符出现次数。
      • 1. 带.的模糊匹配
      • 2. 前缀匹配
        • `字典树`树实现
        • `字典树`数组实现-动态开数组
        • `字典树`数组实现-预先开数组
      • 3. 01字典树-最大异或和
      • 4. 01字典树-限制异或和的数字数量。
      • 5. 01字典树-查询这些数字异或同一个数后的最大值的最小值。
    • 三、其他
    • 四、更多例题
    • 五、参考链接

一、 算法&数据结构

1. 描述

字典树通常用来处理字符串前缀查找问题,不过很多题都可以用cpp_map水过去

2. 复杂度分析

在n个字符串words[i]里,执行m次查找

  1. 单次查询query, O(w),w为word长度
  2. 单次插入update,O(w)
  3. 建树,O(nw)
  4. 总查找,最坏O(nmw),平均O(mw)

3. 常见应用

  1. 求前缀为s的字符串有多少
  2. 集合最大异或和
  3. 还有很多,我不会

4. 常用优化

  1. 发现python写起来太方便啦!直接用字典表示每一层,end的话加特殊key即可(保证key不在字典字符里)。
  2. 树表示比较好写。
  3. 数组表示可以去查查CPP板子

二、 模板代码

0. 添加一个更容易写的字典树-用字典实现每一层。

例题: 648. 单词替换
题意找出每个单词在字典中的最短前缀。
今天的每日一题,看别人代码发现有nb的实现方案,速度又快还节省空间,还代码短。

[python刷题模板] 字典树_第1张图片

class Solution:
    def replaceWords(self, dictionary: List[str], sentence: str) -> str:
        trie = {}
        for s in dictionary:
            pos = trie 
            for c in s:
                if c not in pos:
                    pos[c] = {}
                pos = pos[c]
            pos['is_end'] = s 
        
        def find(s):
            pos = trie
            for c in s:
                if c not in pos:
                    return s 
                pos = pos[c]
                if 'is_end' in pos:
                    return pos['is_end']
            return s
        
        return ' '.join(map(find,sentence.split()))        

-1. 再添加一个dict实现的字典树,记录每个字符出现次数。

例题: 6183. 字符串的前缀分数和
311周赛T4.只能说很水,但之前没有写记cnt的模板,因此比赛时候调了一会。
这题求每个字符串前缀在整个数组中出现过的总次数作为分数,因此find的时候累加cnt即可。
[LeetCode周赛复盘] 第 311 场周赛20220918

class Solution:
    def sumPrefixScores(self, words: List[str]) -> List[int]:
        trie = {}
        for s in words:
            pos = trie
            for c in s:
                if c not in pos:
                    pos[c] = {}
                    pos[c]['cnt'] = 0
                pos = pos[c]
                pos['cnt'] += 1
            pos['end'] = s
        # print(trie)
        
        ans = []
        def find(s):
            pos = trie
            ret = 0
            for c in s:
                if c not in pos:
                    return ret
                ret += pos[c]['cnt']
                pos = pos[c]
                # if 'end' in pos:
                #     return
            return ret
                
        for s in words:
            ans.append(find(s))
        return ans    

1. 带.的模糊匹配

例题: 211. 添加与搜索单词 - 数据结构设计
这题由于点代表任意字符,所以要dfs,写成了树

class TrieNode:
    def __init__(self,cnt=0):
        self.cnt = cnt 
        self.next = [None]*26
        self.is_end = False
    
class WordDictionary:
    def __init__(self):
        self.root = TrieNode()
        self.depth = 0

    def addWord(self, word: str) -> None:
        cur = self.root
        for c in word:
            i = ord(c)-ord('a')
            if not cur.next[i] :  # 没有这个字符
                cur.next[i] = TrieNode()
            cur = cur.next[i]
            cur.cnt += 1
        cur.is_end = True            

    def search(self, word: str) -> bool:
        def dfs(root,word,start):
            if not root:
                return False
            if start == len(word):
                return root.is_end

            c = word[start]
            if c == '.':
                for i in range(26):
                    if root.next[i] and root.next[i].cnt >0 and dfs(root.next[i],word,start+1):
                        return True
                return False

            else :
                i = ord(c)-ord('a')
                if not root.next[i] or root.next[i].cnt == 0:
                    return False
                return dfs(root.next[i],word,start+1)
            
        return dfs(self.root,word,0)

2. 前缀匹配

例题: 1268. 搜索推荐系统

字典树树实现

经典案例

class TrieNode:
    def __init__(self,cnt=0):
        self.cnt = cnt 
        self.next = [None]*26
        self.is_end = False
    def insert(self, word: str) -> None:
        cur = self
        for c in word:
            i = ord(c)-ord('a')
            if not cur.next[i] :  # 没有这个字符
                cur.next[i] = TrieNode()
            cur = cur.next[i]
            cur.cnt += 1
        cur.is_end = True
    def find(self,word):
        ans = []
        cur = self
        now_word=''
        for c in  word:
            now_word+=c
            i = ord(c)-ord('a')
            if not cur.next[i]:
                return ans
            cur = cur.next[i]
        
        def dfs(root,now_word):
            if len(ans) >= 3:
                return ans
            if root.is_end:
                ans.append(now_word)
            for i in range(26):
                c = chr(ord('a')+i)                
                if root.next[i]: 
                    dfs(root.next[i],now_word+c)
            
        dfs(cur,now_word)
        return ans
        
class Solution:
    def suggestedProducts(self, products: List[str], searchWord: str) -> List[List[str]]:
        trie = TrieNode()
        for product in products:
            trie.insert(product)            

        ans = []
        for i in range(len(searchWord)):
            ans.append(trie.find(searchWord[:i+1]))
        
        return ans

字典树数组实现-动态开数组

提交后竟然没有树快

class Trie:
    def __init__(self,c_cnt=26):
        self.c_cnt=c_cnt
        self.trie = [[0]*c_cnt]
        self.is_end = [False]
        self.w = [0]  # 这个节点被经过了几次
    def insert(self,word):
        trie = self.trie
        u = 0
        for c in word:
            i = ord(c)-ord('a')
            if trie[u][i] == 0:
                trie.append([0]*self.c_cnt)
                trie[u][i] = len(trie)-1
                self.is_end.append(False)
                self.w.append(0)
            u = trie[u][i]
            self.w[u] += 1
        self.is_end[u] = True

    def find(self,word):
        # 查找这个前缀的字符串有几个
        trie = self.trie
        u = 0
        for c in word:
            i = ord(c)-ord('a')
            if trie[u][i] == 0:
                return 0
            u = trie[u][i]
        return self.w[u]
    def find_ans(self,word):
        trie = self.trie
        ans = []
        u = 0
        now_word=''
        for c in  word:
            now_word+=c
            i = ord(c)-ord('a')
            if trie[u][i] == 0:
                return ans
            u = trie[u][i]
        
        def dfs(u,now_word):
            if len(ans) >= 3:
                return ans
            if self.is_end[u]:
                ans.append(now_word)
            for i in range(26):
                c = chr(ord('a')+i)                
                if trie[u][i] > 0: 
                    dfs(trie[u][i],now_word+c)
            
        dfs(u,now_word)
        return ans        
            
class Solution:
    def suggestedProducts(self, products: List[str], searchWord: str) -> List[List[str]]:        
        trie = Trie()
        for product in products:
            trie.insert(product)            

        ans = []
        for i in range(len(searchWord)):
            ans.append(trie.find_ans(searchWord[:i+1]))
        
        return ans

字典树数组实现-预先开数组

数组开辟长度应该=len*m,len是单词数,m是最长单词长度,提交后竟然没有动态开数组快

class Trie:
    def __init__(self,size,c_cnt=26):
        self.c_cnt=c_cnt
        self.trie = [[0]*c_cnt for _ in range(size)]
        self.is_end = [False] * (size)
        self.w = [0] * (size)  # 这个节点被经过了几次
        self.total_cnt = 1                            
    def insert(self,word):
        trie = self.trie
        u = 0
        for c in word:
            i = ord(c)-ord('a')
            if trie[u][i] == 0:
                trie[u][i] = self.total_cnt
                self.total_cnt += 1
            u = trie[u][i]
            self.w[u] += 1
        self.is_end[u] = True

    def find(self,word):
        # 查找这个前缀的字符串有几个
        trie = self.trie
        u = 0
        for c in word:
            i = ord(c)-ord('a')
            if trie[u][i] == 0:
                return 0
            u = trie[u][i]
        return self.w[u]
    def find_ans(self,word):
        trie = self.trie
        ans = []
        u = 0
        now_word=''
        for c in  word:
            now_word+=c
            i = ord(c)-ord('a')
            if trie[u][i] == 0:
                return ans
            u = trie[u][i]
        
        def dfs(u,now_word):
            if len(ans) >= 3:
                return ans
            if self.is_end[u]:
                ans.append(now_word)
            for i in range(26):
                c = chr(ord('a')+i)                
                if trie[u][i] > 0: 
                    dfs(trie[u][i],now_word+c)
            
        dfs(u,now_word)
        return ans                
class Solution:
    def suggestedProducts(self, products: List[str], searchWord: str) -> List[List[str]]:
        #     ans.append(trie.find(searchWord[:i+1]))
        max_len = max([len(s) for s in products])
        trie = Trie(max_len*len(products)+1)
        for product in products:
            trie.insert(product)            

        ans = []
        for i in range(len(searchWord)):
            ans.append(trie.find_ans(searchWord[:i+1]))
        
        return ans

3. 01字典树-最大异或和

链接: 421. 数组中两个数的最大异或值

01字典树经典案例,贪心的去找子节点


class TrieXor:
    def __init__(self,nums = None):
        # 01字典树,用来处理异或最值问题,本模板只处理数字最低的31位
        # 用nums初始化字典树,可以为空
        self.tree = {}
        if nums:
            for a in nums:
                self.insert(a)


    def insert(self,num):
        # 01字典树插入一个数字num,只会处理最低31位。
        cur = self.tree 
        for i in range(31,-1,-1):
            nxt = (num>>i)&1
            if nxt not in cur:
                cur[nxt] = {}
            cur = cur[nxt]
      


    def find_max_xor_num(self,num):
        # 计算01字典树里任意数字异或num的最大值,只会处理最低31位。
        # 贪心的从高位开始处理,显然num的某位是0,对应的优先应取1;相反同理
        cur = self.tree
        ret = 0
        for i in range(31,-1,-1):            
            if (num>>i)&1 == 0:  # 如果本位是0,那么取1才最大;取不到1才取0
                if 1 in cur:
                    cur = cur[1]
                    ret += ret + 1
                else:
                    cur = cur.get(0,{})
                    ret <<= 1
            else:
                if 0 in cur:
                    cur = cur[0]
                    ret += ret + 1
                else:
                    cur = cur.get(1,{})
                    ret <<=1
        return ret

class Solution:
    def findMaximumXOR(self, nums: List[int]) -> int:
        trie = TrieXor(nums)    
        return max(trie.find_max_xor_num(x) for x in nums)

4. 01字典树-限制异或和的数字数量。

链接: 1803. 统计异或值在范围内的数对有多少

  • 01字典树,如果我们能找出一堆数字中,有多少个数字异或x小于limit,标记为f(x,limit),
  • 那么问题就转化为求f(x,high+1)-f(x,low) ,其中x随着数字遍历,且遍历完加入字典树。
class TrieXor:
    def __init__(self,nums = None,bit_len=31):
        # 01字典树,用来处理异或最值问题,本模板只处理数字最低的31位
        # 用nums初始化字典树,可以为空
        self.trie = {}
        self.cnt = 0  # 字典树插入了几个值
        if nums:
            for a in nums:
                self.insert(a)
        self.bit_len = bit_len


    def insert(self,num):
        # 01字典树插入一个数字num,只会处理最低bit_len位。
        cur = self.trie 
        for i in range(self.bit_len,-1,-1):
            nxt = (num>>i)&1
            if nxt not in cur:
                cur[nxt] = {}
            cur = cur[nxt]
            cur[3] = cur.get(3,0)+1  # 这个节点被经过了几次
        cur[5] = num   # 记录这个数:'#'或者'end'等非01的当key都行;这里由于key只有01因此用5
        self.cnt += 1


    def find_max_xor_num(self,num):
        # 计算01字典树里任意数字异或num的最大值,只会处理最低bit_len位。
        # 贪心的从高位开始处理,显然num的某位是0,对应的优先应取1;相反同理
        cur = self.trie
        ret = 0
        for i in range(self.bit_len,-1,-1):            
            if (num>>i)&1 == 0:  # 如果本位是0,那么取1才最大;取不到1才取0
                if 1 in cur:
                    cur = cur[1]
                    ret += ret + 1
                else:
                    cur = cur.get(0,{})
                    ret <<= 1
            else:
                if 0 in cur:
                    cur = cur[0]
                    ret += ret + 1
                else:
                    cur = cur.get(1,{})
                    ret <<=1
        return ret
    
    def count_less_than_limit_xor_num(self,num,limit):
        # 计算01字典树里有多少数字异或num后小于limit
        # 由于计算的是严格小于,因此只需要计算三种情况:
        # 1.当limit对应位是1,且异或值为0的子树部分,全部贡献。
        # 2.当limit对应位是1,且异或值为1的子树部分,向后检查。
        # 3.当limit对应为是0,且异或值为0的子树部分,向后检查。
        # 若向后检查取不到,直接剪枝break
        cur = self.trie
        ans = 0
        for i in range(self.bit_len,-1,-1):
            a,b = (num>>i)&1,(limit>>i)&1
            if b == 1:
                if a == 0:
                    if 0 in cur:  # 右子树上所有值异或1都是0,一定小于1
                        ans += cur[0][3]                                     
                    cur = cur.get(1)  # 继续检查右子树
                    if not cur:break  # 如果没有1,即没有右子树,可以直接跳出了                    
                if a == 1:
                    if 1 in cur:  # 右子树上所有值异或1都是0,一定小于1
                        ans += cur[1][3]
                    cur = cur.get(0)  # 继续检查左子树
                    if not cur:break  # 如果没有0,即没有左子树,可以直接跳出了   
            else:
                cur = cur.get(a)  # limit是0,因此只需要检查异或和为0的子树
                if not cur:break  # 如果没有相同边的子树,即等于0的子树,可以直接跳出了    
        return ans


class Solution:
    def countPairs(self, nums: List[int], low: int, high: int) -> int:
        trie = TrieXor(bit_len=15)
        ans = 0
        for x in nums:
            ans += trie.count_less_than_limit_xor_num(x,high+1) - trie.count_less_than_limit_xor_num(x,low)
            trie.insert(x)
        return ans
  • 附一个抄的板子,没详细测试,自己写了几句注释
class BinaryTrie:
    """
    Reference:
     - https://atcoder.jp/contests/arc028/submissions/19916627
     - https://judge.yosupo.jp/submission/35057
    """

    __slots__ = (
        "max_log",
        "x_end",
        "v_list",
        "multiset",
        "add_query_count",
        "add_query_limit",
        "edges",
        "size",
        "is_end",
        "max_v",
        "lazy",
    )

    def __init__(
        self,
        max_log: int = 60,  # 字典树最大深度,取决于值域
        allow_multiple_elements: bool = True,  # 是否允许重复数值
        add_query_limit: int = 10**6,  # 允许添加多少次数字
    ):
        self.max_log = max_log  # 最大深度
        self.x_end = 1 << max_log  # 这个深度下允许的值域最大值
        self.v_list = [0] * (max_log + 1)  
        self.multiset = allow_multiple_elements
        self.add_query_count = 0  # 计数添加了多少次,感觉没用
        self.add_query_limit = add_query_limit  # 最多允许这么多次访问,assert注释了
        n = max_log * add_query_limit + 1
        self.edges = [-1] * (2 * n)
        self.size = [0] * n
        self.is_end = [0] * n
        self.max_v = 0
        self.lazy = 0  # 用于整棵树异或

    def add(self, x: int):
        # assert 0 <= x < self.x_end
        # assert 0 <= self.add_query_count < self.add_query_limit
        x ^= self.lazy
        v = 0
        for i in range(self.max_log - 1, -1, -1):
            d = (x >> i) % 2
            if self.edges[2 * v + d] == -1:
                self.max_v += 1
                self.edges[2 * v + d] = self.max_v
            v = self.edges[2 * v + d]
            self.v_list[i] = v
        if self.multiset or self.is_end[v] == 0:
            self.is_end[v] += 1
            for v in self.v_list:
                self.size[v] += 1
        self.add_query_count += 1

    def discard(self, x: int):
        # 移除一个x
        if not 0 <= x < self.x_end:
            return
        x ^= self.lazy
        v = 0
        for i in range(self.max_log - 1, -1, -1):
            d = (x >> i) % 2
            if self.edges[2 * v + d] == -1:
                return
            v = self.edges[2 * v + d]
            self.v_list[i] = v
        if self.is_end[v] > 0:
            self.is_end[v] -= 1
            for v in self.v_list:
                self.size[v] -= 1

    def erase(self, x: int, count: int = -1):
        # 移除count个x;如果count==-1,则全部移除
        # assert -1 <= count
        if not 0 <= x < self.x_end:
            return
        x ^= self.lazy
        v = 0
        for i in range(self.max_log - 1, -1, -1):
            d = (x >> i) % 2
            if self.edges[2 * v + d] == -1:
                return
            v = self.edges[2 * v + d]
            self.v_list[i] = v
        if count == -1 or self.is_end[v] < count:
            count = self.is_end[v]
        if self.is_end[v] > 0:
            self.is_end[v] -= count
            for v in self.v_list:
                self.size[v] -= count

    def count(self, x: int) -> int:  #
        # 检查有几个x
        if not 0 <= x < self.x_end:
            return 0
        x ^= self.lazy
        v = 0
        for i in range(self.max_log - 1, -1, -1):
            d = (x >> i) % 2
            if self.edges[2 * v + d] == -1:
                return 0
            v = self.edges[2 * v + d]
        return self.is_end[v]

    def bisect_left(self, x: int) -> int:
        # 找第一个大于等于x数的下标;也可以说是找有几个数比x小
        if x < 0:
            return 0
        if self.x_end <= x:
            return len(self)
        v = 0
        ret = 0
        for i in range(self.max_log - 1, -1, -1):
            d = (x >> i) % 2
            l = (self.lazy >> i) % 2
            lc = self.edges[2 * v]
            rc = self.edges[2 * v + 1]
            if l == 1:
                lc, rc = rc, lc
            if d:
                if lc != -1:
                    ret += self.size[lc]
                if rc == -1:
                    return ret
                v = rc
            else:
                if lc == -1:
                    return ret
                v = lc
        return ret

    def bisect_right(self, x: int) -> int:
        # 找第一个比x大的数的下标
        return self.bisect_left(x + 1)

    def index(self, x: int) -> int:
        if x not in self:
            raise ValueError(f"{x} is not in BinaryTrie")
        return self.bisect_left(x)

    def find(self, x: int) -> int:
        # 找x是否在树里,不在返回-1;在返回下标
        if x not in self:
            return -1
        return self.bisect_left(x)

    def kth_elem(self, k: int) -> int:
        # 计算第k小的值,其中k是0-index
        if k < 0:
            k += self.size[0]
        # assert 0 <= k < self.size[0]
        v = 0
        ret = 0
        for i in range(self.max_log - 1, -1, -1):
            l = (self.lazy >> i) % 2
            lc = self.edges[2 * v]
            rc = self.edges[2 * v + 1]
            if l == 1:
                lc, rc = rc, lc
            if lc == -1:
                v = rc
                ret |= 1 << i
                continue
            if self.size[lc] <= k:
                k -= self.size[lc]
                v = rc
                ret |= 1 << i
            else:
                v = lc
        return ret

    def minimum(self) -> int:
        # 返回树里最小的值
        return self.kth_elem(0)

    def maximum(self) -> int:
        # 返回树里最大的值
        return self.kth_elem(-1)

    def xor_all(self, x: int):
        # 把整棵树上所有值都异或x
        # assert 0 <= x < self.x_end
        self.lazy ^= x

    def __iter__(self):
        q = [(0, 0)]
        for i in range(self.max_log - 1, -1, -1):
            l = (self.lazy >> i) % 2
            nq = []
            for v, x in q:
                lc = self.edges[2 * v]
                rc = self.edges[2 * v + 1]
                if l == 1:
                    lc, rc = rc, lc
                if lc != -1:
                    nq.append((lc, 2 * x))
                if rc != -1:
                    nq.append((rc, 2 * x + 1))
            q = nq
        for v, x in q:
            for _ in range(self.is_end[v]):
                yield x

    def __str__(self):
        prefix = "BinaryTrie("
        content = list(map(str, self))
        suffix = ")"
        if content:
            content[0] = prefix + content[0]
            content[-1] = content[-1] + suffix
        else:
            content = [prefix + suffix]
        return ", ".join(content)

    def __getitem__(self, k):
        return self.kth_elem(k)

    def __contains__(self, x: int) -> bool:
        return bool(self.count(x))

    def __len__(self):
        return self.size[0]

    def __bool__(self):
        return bool(len(self))

    def __ixor__(self, x: int):
        self.xor_all(x)
        return self



class Solution:
        def countPairs(self, nums: List[int], low: int, high: int) -> int:
            n = len(nums)
            max_log = max(nums).bit_length()
            bt = BinaryTrie(add_query_limit=n, max_log=max_log, allow_multiple_elements=True)
            ans = 0
            for num in nums:
                bt.xor_all(num)
                ans += bt.bisect_right(high) - bt.bisect_left(low)
                bt.xor_all(num)
                bt.add(num)
            return ans        

5. 01字典树-查询这些数字异或同一个数后的最大值的最小值。

链接: 4869. 异或值

  • 不用字典树直接递归也可以而且表现更优。
  • 按位从高到低考虑即可,但要递归。因此01字典树建树考虑更形象。
  • 实际做的时候封装后的类TLE了,拆出来才过。
# Problem: 异或值
# Contest: AcWing
# URL: https://www.acwing.com/problem/content/4872/
# Memory Limit: 256 MB
# Time Limit: 1000 ms

import sys

RI = lambda: map(int, sys.stdin.buffer.readline().split())
RS = lambda: map(bytes.decode, sys.stdin.buffer.readline().strip().split())
RILST = lambda: list(RI())
DEBUG = lambda *x: sys.stderr.write(f'{str(x)}\n')

MOD = 10 ** 9 + 7
PROBLEM = """
"""


class TrieXor:
    def __init__(self, nums=None, bit_len=31):
        # 01字典树,用来处理异或最值问题,本模板只处理数字最低的31位
        # 用nums初始化字典树,可以为空
        self.trie = {}
        self.cnt = 0  # 字典树插入了几个值
        if nums:
            for a in nums:
                self.insert(a)
        self.bit_len = bit_len

    def insert(self, num):
        # 01字典树插入一个数字num,只会处理最低bit_len位。
        cur = self.trie
        for i in range(self.bit_len - 1, -1, -1):
            nxt = (num >> i) & 1
            if nxt not in cur:
                cur[nxt] = {}
            cur = cur[nxt]
            cur[3] = cur.get(3, 0) + 1  # 这个节点被经过了几次
        cur[5] = num  # 记录这个数:'#'或者'end'等非01的当key都行;这里由于key只有01因此用5
        self.cnt += 1

    def find_max_xor_num(self, num):
        # 计算01字典树里任意数字异或num的最大值,只会处理最低bit_len位。
        # 贪心的从高位开始处理,显然num的某位是0,对应的优先应取1;相反同理
        cur = self.trie
        ret = 0
        for i in range(self.bit_len - 1, -1, -1):
            if (num >> i) & 1 == 0:  # 如果本位是0,那么取1才最大;取不到1才取0
                if 1 in cur:
                    cur = cur[1]
                    ret += ret + 1
                else:
                    cur = cur.get(0, {})
                    ret <<= 1
            else:
                if 0 in cur:
                    cur = cur[0]
                    ret += ret + 1
                else:
                    cur = cur.get(1, {})
                    ret <<= 1
        return ret

    def find_max_xor_any(self):
        """计算所有数字异或异或同一数字x时,结果里max的最小值"""

        def dfs(cur, bit):  # 计算当前层以下能取到的最小的最大值
            if bit < 0:
                return 0
            if 0 not in cur:  # 如果这层都是1,那么可以使x的这层是1,结果里的这层就是0,递归下一层即可。
                return dfs(cur[1], bit - 1)
            elif 1 not in cur:  # 如果这层都是0,使x这层是0,递归下一层。
                return dfs(cur[0], bit - 1)
            # 如果01都有,那么x这层不管是几,结果最大值里这层都是1,那么考虑走1还是走0方向,取min后加上本层的值。
            return min(dfs(cur[0], bit - 1), dfs(cur[1], bit - 1)) + (1 << bit)

        return dfs(self.trie, self.bit_len - 1)

    def count_less_than_limit_xor_num(self, num, limit):
        # 计算01字典树里有多少数字异或num后小于limit
        # 由于计算的是严格小于,因此只需要计算三种情况:
        # 1.当limit对应位是1,且异或值为0的子树部分,全部贡献。
        # 2.当limit对应位是1,且异或值为1的子树部分,向后检查。
        # 3.当limit对应为是0,且异或值为0的子树部分,向后检查。
        # 若向后检查取不到,直接剪枝break
        cur = self.trie
        ans = 0
        for i in range(self.bit_len - 1, -1, -1):
            a, b = (num >> i) & 1, (limit >> i) & 1
            if b == 1:
                if a == 0:
                    if 0 in cur:  # 右子树上所有值异或1都是0,一定小于1
                        ans += cur[0][3]
                    cur = cur.get(1)  # 继续检查右子树
                    if not cur: break  # 如果没有1,即没有右子树,可以直接跳出了
                if a == 1:
                    if 1 in cur:  # 右子树上所有值异或1都是0,一定小于1
                        ans += cur[1][3]
                    cur = cur.get(0)  # 继续检查左子树
                    if not cur: break  # 如果没有0,即没有左子树,可以直接跳出了
            else:
                cur = cur.get(a)  # limit是0,因此只需要检查异或和为0的子树
                if not cur: break  # 如果没有相同边的子树,即等于0的子树,可以直接跳出了
        return ans


#    封装成类卡常真是吐了   ms
def solve_tle():
    n, = RI()
    a = RILST()
    trie = TrieXor(bit_len=30)
    for x in a:
        trie.insert(x)
    ans = trie.find_max_xor_any()
    print(ans)


#   7224    ms
def solve1():
    n, = RI()
    a = RILST()
    trie = {}
    for x in a:
        cur = trie
        for i in range(29, -1, -1):
            nxt = (x >> i) & 1
            if nxt not in cur:
                cur[nxt] = {}
            cur = cur[nxt]

    def dfs(cur, bit):  # 计算当前层以下能取到的最小的最大值
        if bit < 0:
            return 0
        if 0 not in cur:  # 如果这层都是1,那么可以使x的这层是1,结果里的这层就是0,递归下一层即可。
            return dfs(cur[1], bit - 1)
        elif 1 not in cur:  # 如果这层都是0,使x这层是0,递归下一层。
            return dfs(cur[0], bit - 1)
        # 如果01都有,那么x这层不管是几,结果最大值里这层都是1,那么考虑走1还是走0方向,取min后加上本层的值。
        return min(dfs(cur[0], bit - 1), dfs(cur[1], bit - 1)) + (1 << bit)

    ans = dfs(trie, 29)
    print(ans)


#     1725   ms
def solve():
    n, = RI()
    a = RILST()

    def dfs(a, bit):  # 计算当前层以下能取到的最小的最大值
        if bit < 0:
            return 0
        x, y = [], []
        t = 1 << bit
        for v in a:
            if v & t:
                x.append(v)
            else:
                y.append(v)
        if not x: return dfs(y, bit - 1)
        if not y: return dfs(x, bit - 1)
        # 如果01都有,那么x这层不管是几,结果最大值里这层都是1,那么考虑走1还是走0方向,取min后加上本层的值。
        return min(dfs(x, bit - 1), dfs(y, bit - 1)) + t

    print(dfs(a, 29))


if __name__ == '__main__':
    solve()

三、其他

  1. 字典树也有很多用法,但是我不太会:
    1、前缀查询
    2、最短前缀表示
    3、删除字符串
    4、删除前缀
    5、集合前缀
    6、离线算法
    7、模糊匹配
    8、集合最大异或
    9、树的异或最长路。

四、更多例题

  • 1707. 与数组中元素的最大异或值,可以离线,套用第三题模板
  • 1032. 字符流,逆序建树。

五、参考链接

  • 链接: 夜深人静写算法(七)- 字典树

你可能感兴趣的:(python刷题模板,python,深度优先,leetcode,算法,数据结构)