Trie

I. Definition

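A trie, also known as a word search tree, is a tree-shaped data structure (the name comes from the word "retrieval").
Search engines often use tries to compute word frequencies over text. Its defining feature:
it exploits the common prefixes of strings to cut query time and keep unnecessary string comparisons to a minimum.

  • A search hit takes time proportional to the length of the key being looked up;
  • A search miss only needs to examine a few characters.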
[Figure 1-1-1: Diagram of a trie]
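
The two properties above can be checked with a rough sketch that counts how many characters a search examines before it stops. The class `InspectCountDemo`, its method names, and the restriction to lowercase keys (R = 26) are assumptions made for this illustration only.

```java
// Sketch: count the characters a trie search examines for a hit and for a miss.
class InspectCountDemo {
    static final int R = 26;                  // assumption: keys use only 'a'..'z'

    static class Node {
        boolean isKey;                        // true if a key ends at this node
        Node[] next = new Node[R];
    }

    Node root = new Node();

    void add(String key) {
        Node x = root;
        for (int d = 0; d < key.length(); d++) {
            int i = key.charAt(d) - 'a';
            if (x.next[i] == null) x.next[i] = new Node();
            x = x.next[i];
        }
        x.isKey = true;
    }

    // Number of characters of key examined before the search stops.
    int charsExamined(String key) {
        Node x = root;
        int examined = 0;
        for (int d = 0; d < key.length(); d++) {
            examined++;
            x = x.next[key.charAt(d) - 'a'];
            if (x == null) break;             // null link: the search is a miss
        }
        return examined;
    }

    public static void main(String[] args) {
        InspectCountDemo t = new InspectCountDemo();
        t.add("she");
        t.add("shells");
        t.add("sea");
        System.out.println(t.charsExamined("shells"));     // 6: hit, one step per character
        System.out.println(t.charsExamined("shoreline"));  // 3: miss at the third character
        System.out.println(t.charsExamined("xylophone"));  // 1: miss at the first character
    }
}
```

The hit costs exactly as many steps as the key has characters, while both misses stop as soon as the path leaves the trie.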

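1.1 Data structure

There are many ways to implement a trie. The implementation in this article is called an "R-way trie" (R is the alphabet size):

  1. The root node of the trie stores no character (it can also be viewed as storing the empty string "");
  2. Each node holds R links (R is the alphabet size); in every node, next[i] points to the subtrie for character i;
  3. The value associated with a key is stored in the node for the key's last character (a node whose value is null has no corresponding key in the trie).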
public class TrieST<Value> {
    private static final int R = 256;   // extended ASCII
    private Node root;              // root of trie
    private int n;                  // number of keys in trie
    // R-way trie node
    private static class Node {
        private Object val;             // kept as Object: Java cannot create arrays of a generic node type
        private Node[] next = new Node[R];
    }
}
[Figure 1-2: Representation of a trie (R = 26)]

1.2 API

[Figure 1-2: API of the trie]

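II. Implementation

2.1 Search

A trie search proceeds as follows:

  1. Start the search at the root node;
  2. Take the first character of the search key, choose the matching subtrie, and continue the search there;
  3. In that subtrie, take the second character of the key and again choose the matching subtrie;
  4. Repeat this process for the remaining characters;
  5. Once all characters of the key have been consumed at some node, read the information attached to that node; the search is complete.

A search has three possible outcomes:

  1. The node for the key's last character holds a null value (miss);
  2. The node for the key's last character holds a non-null value (hit);
  3. The search ends at a null link (miss).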
[Figure 2-1: Searching a trie]
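
The three outcomes can be reproduced with a minimal sketch. The class `SearchOutcomeDemo`, the `Outcome` enum, and the lowercase-only alphabet (R = 26) are assumptions introduced for illustration; they are not part of the implementation discussed in this article.

```java
// Sketch: classify a trie search into the three outcomes listed above.
class SearchOutcomeDemo {
    static final int R = 26;                  // assumption: keys use only 'a'..'z'

    enum Outcome { HIT, MISS_NULL_VALUE, MISS_NULL_LINK }

    static class Node {
        Integer val;                          // null means no key ends at this node
        Node[] next = new Node[R];
    }

    Node root = new Node();

    void put(String key, int val) {
        Node x = root;
        for (int d = 0; d < key.length(); d++) {
            int i = key.charAt(d) - 'a';
            if (x.next[i] == null) x.next[i] = new Node();
            x = x.next[i];
        }
        x.val = val;
    }

    Outcome search(String key) {
        Node x = root;
        for (int d = 0; d < key.length(); d++) {
            x = x.next[key.charAt(d) - 'a'];
            if (x == null) return Outcome.MISS_NULL_LINK;    // outcome 3
        }
        return x.val != null ? Outcome.HIT                    // outcome 2
                             : Outcome.MISS_NULL_VALUE;       // outcome 1
    }

    public static void main(String[] args) {
        SearchOutcomeDemo t = new SearchOutcomeDemo();
        t.put("shells", 1);
        t.put("she", 2);
        System.out.println(t.search("she"));    // HIT
        System.out.println(t.search("shell"));  // MISS_NULL_VALUE: the path exists, but no value is stored
        System.out.println(t.search("shore"));  // MISS_NULL_LINK: the path breaks off after "sh"
    }
}
```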

Search, source code:

public Value get(String key) {
    if (key == null)
        throw new IllegalArgumentException("argument to get() is null");
    Node x = get(root, key, 0);
    if (x == null)
        return null;
    return (Value) x.val;
}
// Follows key from its d-th character in the subtrie rooted at x;
// returns the node for the last character of key, or null if a null link is reached
private Node get(Node x, String key, int d) {
    if (x == null)
        return null;
    if (d == key.length())
        return x;
    char c = key.charAt(d);
    return get(x.next[c], key, d+1);
}

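2.2 Insertion

Insertion starts with a search: in a trie this means following the characters of the key until reaching either the node for its last character or a null link.
There are two cases:

  1. A null link is reached before the last character of the key: create a node for each remaining character and store the value in the node for the last one;
  2. The last character of the key is reached before any null link: store the value in that existing node.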
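
A small sketch of the two cases, counting how many nodes each insertion has to create. The class `InsertCasesDemo` and its iterative `put` are assumptions for illustration (the article's own version is recursive), and keys are assumed to use only lowercase letters.

```java
// Sketch: an insertion creates new nodes only where it runs into a null link.
class InsertCasesDemo {
    static final int R = 26;                  // assumption: keys use only 'a'..'z'

    static class Node {
        Integer val;
        Node[] next = new Node[R];
    }

    Node root = new Node();

    // Inserts key and returns the number of nodes that had to be created.
    int put(String key, int val) {
        int created = 0;
        Node x = root;
        for (int d = 0; d < key.length(); d++) {
            int i = key.charAt(d) - 'a';
            if (x.next[i] == null) {          // case 1: null link before the last character
                x.next[i] = new Node();
                created++;
            }
            x = x.next[i];
        }
        x.val = val;                          // case 2 arrives here without creating anything
        return created;
    }

    public static void main(String[] args) {
        InsertCasesDemo t = new InsertCasesDemo();
        System.out.println(t.put("shells", 1));  // 6: every character needs a new node
        System.out.println(t.put("she", 2));     // 0: the whole path already exists
        System.out.println(t.put("shore", 3));   // 3: "sh" is shared, "ore" is new
    }
}
```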

Insertion, source code:

public void put(String key, Value val) {
    if (key == null)
        throw new IllegalArgumentException("first argument to put() is null");
    root = put(root, key, val, 0);
}
// Inserts key (from its d-th character onward) into the subtrie rooted at x
// and returns the root of the updated subtrie
private Node put(Node x, String key, Value val, int d) {
    if (x == null)
        x = new Node();
    if (d == key.length()) {
        if (x.val == null)
            n++;
        x.val = val;
        return x;
    }
    char c = key.charAt(d);
    x.next[c] = put(x.next[c], key, val, d + 1);
    return x;
}

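2.3 Deletion

Deleting a key works as follows:

  1. Find the node for the key and set its value to null;
  2. Check whether that node has any non-null link to a child node.
    If it does, return immediately;
    if it does not, remove the node. If removing it leaves its parent with a null value and no non-null links, remove the parent as well, and so on up the trie.

Note: after the recursive call on a node x returns, the method returns null if x has a null value and all of its links are null; otherwise it returns x.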

[Figure 2-3: Deleting from a trie]
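
The pruning behaviour can be observed by counting nodes before and after each deletion. The sketch below follows the recursive scheme described above; the class `DeletePruneDemo`, the `countNodes` helper, and the lowercase-only alphabet are assumptions for illustration.

```java
// Sketch: deletion clears the value and prunes nodes that become completely empty.
class DeletePruneDemo {
    static final int R = 26;                  // assumption: keys use only 'a'..'z'

    static class Node {
        Integer val;
        Node[] next = new Node[R];
    }

    Node root;

    void put(String key, int val) { root = put(root, key, val, 0); }

    private Node put(Node x, String key, int val, int d) {
        if (x == null) x = new Node();
        if (d == key.length()) { x.val = val; return x; }
        int i = key.charAt(d) - 'a';
        x.next[i] = put(x.next[i], key, val, d + 1);
        return x;
    }

    void delete(String key) { root = delete(root, key, 0); }

    private Node delete(Node x, String key, int d) {
        if (x == null) return null;
        if (d == key.length()) {
            x.val = null;
        } else {
            int i = key.charAt(d) - 'a';
            x.next[i] = delete(x.next[i], key, d + 1);
        }
        if (x.val != null) return x;          // still holds a key: keep it
        for (Node child : x.next)
            if (child != null) return x;      // still has a subtrie: keep it
        return null;                          // empty node: unlink it from its parent
    }

    int countNodes() { return count(root); }

    private static int count(Node x) {
        if (x == null) return 0;
        int total = 1;
        for (Node child : x.next) total += count(child);
        return total;
    }

    public static void main(String[] args) {
        DeletePruneDemo t = new DeletePruneDemo();
        t.put("she", 1);
        t.put("shells", 2);
        System.out.println(t.countNodes());   // 7: root plus s, h, e, l, l, s
        t.delete("she");
        System.out.println(t.countNodes());   // 7: the node for "she" still leads to "shells"
        t.delete("shells");
        System.out.println(t.countNodes());   // 0: every node is now empty and gets pruned
    }
}
```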

Deletion, source code:

public void delete(String key) {
    if (key == null)
        throw new IllegalArgumentException("argument to delete() is null");
    root = delete(root, key, 0);
}
// Deletes key from the subtrie rooted at x and returns the root of the updated subtrie (null if it became empty)
private Node delete(Node x, String key, int d) {
    if (x == null) return null;
    if (d == key.length()) {
        if (x.val != null) n--;
        x.val = null;
    }
    else {
        char c = key.charAt(d);
        x.next[c] = delete(x.next[c], key, d+1);
    }
    // remove subtrie rooted at x if it is completely empty
    if (x.val != null) return x;
    for (int c = 0; c < R; c++)
        if (x.next[c] != null)
            return x;
    return null;
}

2.4 Traversal

Traversal collects all keys stored in the trie.
Note: the root node effectively stores the empty string "".

[Figure 2-4: Traversing a trie]
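
As a self-contained illustration of the traversal, the sketch below collects keys into a `java.util.ArrayList` instead of a queue class. The class `PrefixCollectDemo` and the lowercase-only alphabet are assumptions; the recursion mirrors the collect scheme used in this article.

```java
import java.util.ArrayList;
import java.util.List;

// Sketch: collect all keys under a prefix by walking the subtrie in character order.
class PrefixCollectDemo {
    static final int R = 26;                  // assumption: keys use only 'a'..'z'

    static class Node {
        Integer val;
        Node[] next = new Node[R];
    }

    Node root = new Node();

    void put(String key, int val) {
        Node x = root;
        for (int d = 0; d < key.length(); d++) {
            int i = key.charAt(d) - 'a';
            if (x.next[i] == null) x.next[i] = new Node();
            x = x.next[i];
        }
        x.val = val;
    }

    List<String> keysWithPrefix(String prefix) {
        List<String> results = new ArrayList<>();
        Node x = root;                        // walk down to the node for prefix
        for (int d = 0; d < prefix.length() && x != null; d++)
            x = x.next[prefix.charAt(d) - 'a'];
        collect(x, new StringBuilder(prefix), results);
        return results;
    }

    // prefix always spells the path from the root to x
    private void collect(Node x, StringBuilder prefix, List<String> results) {
        if (x == null) return;
        if (x.val != null) results.add(prefix.toString());
        for (int i = 0; i < R; i++) {
            prefix.append((char) ('a' + i));
            collect(x.next[i], prefix, results);
            prefix.deleteCharAt(prefix.length() - 1);
        }
    }

    public static void main(String[] args) {
        PrefixCollectDemo t = new PrefixCollectDemo();
        String[] keys = { "shells", "by", "she", "shore", "sea" };
        for (int i = 0; i < keys.length; i++) t.put(keys[i], i);
        System.out.println(t.keysWithPrefix(""));    // [by, sea, she, shells, shore]
        System.out.println(t.keysWithPrefix("sh"));  // [she, shells, shore]
    }
}
```

Because the links are visited in character order, the keys come out sorted regardless of insertion order.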

Traversal, source code:

public Iterable<String> keys() {
    return keysWithPrefix("");
}
// Returns all keys that start with prefix
public Iterable<String> keysWithPrefix(String prefix) {
    Queue<String> results = new Queue<String>();   // the queue class with enqueue() used in the complete source in 2.5
    // find the node x that corresponds to prefix
    Node x = get(root, prefix, 0);
    // collect every key in the subtrie rooted at x
    collect(x, new StringBuilder(prefix), results);
    return results;
}
// Collects every key in the subtrie rooted at x.
// Note: prefix holds all characters on the path from root to x
private void collect(Node x, StringBuilder prefix, Queue<String> results) {
    if (x == null) return;
    if (x.val != null)
        results.enqueue(prefix.toString());
    for (char c = 0; c < R; c++) {
        prefix.append(c);
        collect(x.next[c], prefix, results);
        prefix.deleteCharAt(prefix.length() - 1);
    }
}

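2.5 Complete source code

// Queue, StdIn and StdOut are not JDK classes; this listing assumes they
// come from the algs4 library (package edu.princeton.cs.algs4).
import edu.princeton.cs.algs4.Queue;
import edu.princeton.cs.algs4.StdIn;
import edu.princeton.cs.algs4.StdOut;

public class TrieST<Value> {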
    private static final int R = 256;        // extended ASCII
    private Node root;      // root of trie
    private int n;          // number of keys in trie

    // R-way trie node
    private static class Node {
        private Object val;
        private Node[] next = new Node[R];
    }

    public TrieST() {
    }

    /**
     * Returns the value associated with the given key.
     * @param key the key
     * @return the value associated with the given key if the key is in the symbol table
     *     and {@code null} if the key is not in the symbol table
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public Value get(String key) {
        if (key == null) throw new IllegalArgumentException("argument to get() is null");
        Node x = get(root, key, 0);
        if (x == null) return null;
        return (Value) x.val;
    }

    /**
     * Does this symbol table contain the given key?
     * @param key the key
     * @return {@code true} if this symbol table contains {@code key} and
     *     {@code false} otherwise
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public boolean contains(String key) {
        if (key == null) throw new IllegalArgumentException("argument to contains() is null");
        return get(key) != null;
    }

    private Node get(Node x, String key, int d) {
        if (x == null) return null;
        if (d == key.length()) return x;
        char c = key.charAt(d);
        return get(x.next[c], key, d+1);
    }

    /**
     * Inserts the key-value pair into the symbol table, overwriting the old value
     * with the new value if the key is already in the symbol table.
     * If the value is {@code null}, this effectively deletes the key from the symbol table.
     * @param key the key
     * @param val the value
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public void put(String key, Value val) {
        if (key == null) throw new IllegalArgumentException("first argument to put() is null");
        if (val == null) delete(key);
        else root = put(root, key, val, 0);
    }

    private Node put(Node x, String key, Value val, int d) {
        if (x == null) x = new Node();
        if (d == key.length()) {
            if (x.val == null) n++;
            x.val = val;
            return x;
        }
        char c = key.charAt(d);
        x.next[c] = put(x.next[c], key, val, d+1);
        return x;
    }

    /**
     * Returns the number of key-value pairs in this symbol table.
     * @return the number of key-value pairs in this symbol table
     */
    public int size() {
        return n;
    }

    /**
     * Is this symbol table empty?
     * @return {@code true} if this symbol table is empty and {@code false} otherwise
     */
    public boolean isEmpty() {
        return size() == 0;
    }

    /**
     * Returns all keys in the symbol table as an {@code Iterable}.
     * To iterate over all of the keys in the symbol table named {@code st},
     * use the foreach notation: {@code for (Key key : st.keys())}.
     * @return all keys in the symbol table as an {@code Iterable}
     */
    public Iterable<String> keys() {
        return keysWithPrefix("");
    }

    /**
     * Returns all of the keys in the set that start with {@code prefix}.
     * @param prefix the prefix
     * @return all of the keys in the set that start with {@code prefix},
     *     as an iterable
     */
    public Iterable<String> keysWithPrefix(String prefix) {
        Queue<String> results = new Queue<String>();
        Node x = get(root, prefix, 0);
        collect(x, new StringBuilder(prefix), results);
        return results;
    }

    private void collect(Node x, StringBuilder prefix, Queue<String> results) {
        if (x == null) return;
        if (x.val != null) results.enqueue(prefix.toString());
        for (char c = 0; c < R; c++) {
            prefix.append(c);
            collect(x.next[c], prefix, results);
            prefix.deleteCharAt(prefix.length() - 1);
        }
    }

    /**
     * Returns all of the keys in the symbol table that match {@code pattern},
     * where . symbol is treated as a wildcard character.
     * @param pattern the pattern
     * @return all of the keys in the symbol table that match {@code pattern},
     *     as an iterable, where . is treated as a wildcard character.
     */
    public Iterable<String> keysThatMatch(String pattern) {
        Queue<String> results = new Queue<String>();
        collect(root, new StringBuilder(), pattern, results);
        return results;
    }

    private void collect(Node x, StringBuilder prefix, String pattern, Queue<String> results) {
        if (x == null) return;
        int d = prefix.length();
        if (d == pattern.length() && x.val != null)
            results.enqueue(prefix.toString());
        if (d == pattern.length())
            return;
        char c = pattern.charAt(d);
        if (c == '.') {
            for (char ch = 0; ch < R; ch++) {
                prefix.append(ch);
                collect(x.next[ch], prefix, pattern, results);
                prefix.deleteCharAt(prefix.length() - 1);
            }
        }
        else {
            prefix.append(c);
            collect(x.next[c], prefix, pattern, results);
            prefix.deleteCharAt(prefix.length() - 1);
        }
    }

    /**
     * Returns the string in the symbol table that is the longest prefix of {@code query},
     * or {@code null}, if no such string.
     * @param query the query string
     * @return the string in the symbol table that is the longest prefix of {@code query},
     *     or {@code null} if no such string
     * @throws IllegalArgumentException if {@code query} is {@code null}
     */
    public String longestPrefixOf(String query) {
        if (query == null) throw new IllegalArgumentException("argument to longestPrefixOf() is null");
        int length = longestPrefixOf(root, query, 0, -1);
        if (length == -1) return null;
        else return query.substring(0, length);
    }

    // returns the length of the longest string key in the subtrie
    // rooted at x that is a prefix of the query string,
    // assuming the first d character match and we have already
    // found a prefix match of given length (-1 if no such match)
    private int longestPrefixOf(Node x, String query, int d, int length) {
        if (x == null) return length;
        if (x.val != null) length = d;
        if (d == query.length()) return length;
        char c = query.charAt(d);
        return longestPrefixOf(x.next[c], query, d+1, length);
    }

    /**
     * Removes the key from the set if the key is present.
     * @param key the key
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public void delete(String key) {
        if (key == null) throw new IllegalArgumentException("argument to delete() is null");
        root = delete(root, key, 0);
    }

    private Node delete(Node x, String key, int d) {
        if (x == null) return null;
        if (d == key.length()) {
            if (x.val != null) n--;
            x.val = null;
        }
        else {
            char c = key.charAt(d);
            x.next[c] = delete(x.next[c], key, d+1);
        }

        // remove subtrie rooted at x if it is completely empty
        if (x.val != null) return x;
        for (int c = 0; c < R; c++)
            if (x.next[c] != null)
                return x;
        return null;
    }

    /**
     * Unit tests the {@code TrieST} data type.
     *
     * @param args the command-line arguments
     */
    public static void main(String[] args) {

        // build symbol table from standard input
        TrieST<Integer> st = new TrieST<Integer>();
        for (int i = 0; !StdIn.isEmpty(); i++) {
            String key = StdIn.readString();
            st.put(key, i);
        }

        // print results
        if (st.size() < 100) {
            StdOut.println("keys(\"\"):");
            for (String key : st.keys()) {
                StdOut.println(key + " " + st.get(key));
            }
            StdOut.println();
        }

        StdOut.println("longestPrefixOf(\"shellsort\"):");
        StdOut.println(st.longestPrefixOf("shellsort"));
        StdOut.println();

        StdOut.println("longestPrefixOf(\"quicksort\"):");
        StdOut.println(st.longestPrefixOf("quicksort"));
        StdOut.println();

        StdOut.println("keysWithPrefix(\"shor\"):");
        for (String s : st.keysWithPrefix("shor"))
            StdOut.println(s);
        StdOut.println();

        StdOut.println("keysThatMatch(\".he.l.\"):");
        for (String s : st.keysThatMatch(".he.l."))
            StdOut.println(s);
    }
}

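III. Performance analysis

  • Time complexity
    The shape of a trie does not depend on the order in which keys are inserted or deleted.
    Search cost depends only on the height of the trie, which is determined by the key lengths.
    For an alphabet of size R, in a trie built from N random keys, a search miss examines about log_R N nodes on average.

  • Space complexity
    An R-way trie contains between RN and RNw links
    (R: alphabet size, N: number of keys, w: average key length),
    so R-way tries are a poor fit for keys drawn from a large alphabet.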
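
To make these bounds concrete, the sketch below plugs sample numbers into the formulas. The values N = 1,000,000 and w = 10 are arbitrary assumptions chosen for illustration, and the class and method names are hypothetical.

```java
// Sketch: evaluate the R-way trie bounds RN, RNw and log_R N for sample values.
class RWayCostDemo {
    // Lower bound on the number of links: R * N
    static long minLinks(int r, long n) { return r * n; }

    // Upper bound on the number of links: R * N * w
    static long maxLinks(int r, long n, long w) { return r * n * w; }

    // Average number of nodes examined by a search miss: log_R N
    static double avgMissNodes(int r, long n) { return Math.log(n) / Math.log(r); }

    public static void main(String[] args) {
        long n = 1_000_000L;   // assumed number of keys
        long w = 10;           // assumed average key length
        System.out.println(minLinks(256, n));       // 256000000
        System.out.println(maxLinks(256, n, w));    // 2560000000
        System.out.println(minLinks(65536, n));     // 65536000000 for a 16-bit alphabet
        System.out.printf("%.2f%n", avgMissNodes(256, n));  // about 2.49 nodes per miss
    }
}
```

Even the lower bound grows linearly with R, which is why moving from extended ASCII to a 16-bit alphabet multiplies the link count by 256.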

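IV. Ternary search tries

4.1 Definition

In an R-way trie, a large alphabet size R consumes a great deal of space. A data structure called the "ternary search trie" (TST) reduces this cost.

In a ternary search trie, every node holds one character, three links, and one value. The three links correspond to keys whose current character is less than, equal to, or greater than the node's character; a character of the key is matched only when the search follows a middle link.
This is equivalent to implementing each node of an R-way trie as a binary search tree keyed on the characters of its non-null links.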

[Figure 4-1: Correspondence between a ternary search trie and an R-way trie]

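Data structure definition:

public class TST<Value> {
    private int n;           // size
    private Node root;    // root of TST
    // non-static inner class, so that Node can use the type parameter Value
    private class Node {
        private char c;                 // character
        private Node left, mid, right;  // left, middle, and right subtries
        private Value val;              // value associated with the string
    }
}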

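4.2 Implementation

4.2.1 Search

Search steps:

  1. Compare the first character of the key with the character in the root node.
    If the key's character is smaller, follow the left link;
    if it is larger, follow the right link;
    if it is equal, follow the middle link and move on to the next character of the key.
  2. Repeat step 1 recursively;
  3. Stop at a null link or at the end of the key.
    A null link means a miss;
    reaching the end of the key at a node with a null value means a miss;
    reaching the end of the key at a node with a non-null value means a hit.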
[Figure 4-2-1: Searching a ternary search trie]

Search, source code:

public Value get(String key) {
    if (key == null)
        throw new IllegalArgumentException("calls get() with null argument");
    if (key.length() == 0)
        throw new IllegalArgumentException("key must have length >= 1");
    Node x = get(root, key, 0);
    if (x == null)
        return null;
    return x.val;
}
// Searches the subtrie rooted at x for key, starting at its d-th character
private Node get(Node x, String key, int d) {
    if (x == null)
        return null;
    if (key.length() == 0)
        throw new IllegalArgumentException("key must have length >= 1");
    char c = key.charAt(d);
    if (c < x.c)
        return get(x.left, key, d);
    else if (c > x.c)
        return get(x.right, key, d);
    else {
        if (d < key.length() - 1)
            return get(x.mid, key, d + 1);
        else
            return x;
    }
}

4.2.2 Insertion

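Insertion, source code:

public void put(String key, Value val) {
    if (key == null) {
        throw new IllegalArgumentException("calls put() with null key");
    }
    boolean present = contains(key);
    if (!present && val != null) n++;        // a new key is being added
    else if (present && val == null) n--;    // a null value removes an existing key
    root = put(root, key, val, 0);
}
// Inserts key (from its d-th character onward) into the subtrie rooted at x
// and returns the root of the updated subtrie
private Node put(Node x, String key, Value val, int d) {
    char c = key.charAt(d);
    if (x == null) {
        x = new Node();
        x.c = c;
    }
    if (c < x.c)
        x.left = put(x.left, key, val, d);
    else if (c > x.c)
        x.right = put(x.right, key, val, d);
    else {
        if (d < key.length() - 1)
            x.mid = put(x.mid, key, val, d + 1);
        else
            x.val = val;
    }
    return x;
}

4.2.3 Complete source code

// Queue, StdIn and StdOut are not JDK classes; this listing assumes they
// come from the algs4 library (package edu.princeton.cs.algs4).
import edu.princeton.cs.algs4.Queue;
import edu.princeton.cs.algs4.StdIn;
import edu.princeton.cs.algs4.StdOut;

public class TST<Value> {
    private int n;              // size
    private Node root;   // root of TST
    // non-static inner class, so that Node can use the type parameter Value
    private class Node {
        private char c;                        // character
        private Node left, mid, right;  // left, middle, and right subtries
        private Value val;                     // value associated with string
    }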

    public TST() {
    }

    /**
     * Returns the number of key-value pairs in this symbol table.
     * @return the number of key-value pairs in this symbol table
     */
    public int size() {
        return n;
    }

    /**
     * Does this symbol table contain the given key?
     * @param key the key
     * @return {@code true} if this symbol table contains {@code key} and
     *     {@code false} otherwise
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public boolean contains(String key) {
        if (key == null) {
            throw new IllegalArgumentException("argument to contains() is null");
        }
        return get(key) != null;
    }

    /**
     * Returns the value associated with the given key.
     * @param key the key
     * @return the value associated with the given key if the key is in the symbol table
     *     and {@code null} if the key is not in the symbol table
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public Value get(String key) {
        if (key == null) {
            throw new IllegalArgumentException("calls get() with null argument");
        }
        if (key.length() == 0) throw new IllegalArgumentException("key must have length >= 1");
        Node x = get(root, key, 0);
        if (x == null) return null;
        return x.val;
    }

    // return subtrie corresponding to given key
    private Node get(Node x, String key, int d) {
        if (x == null) return null;
        if (key.length() == 0) throw new IllegalArgumentException("key must have length >= 1");
        char c = key.charAt(d);
        if      (c < x.c)              return get(x.left,  key, d);
        else if (c > x.c)              return get(x.right, key, d);
        else if (d < key.length() - 1) return get(x.mid,   key, d+1);
        else                           return x;
    }

    /**
     * Inserts the key-value pair into the symbol table, overwriting the old value
     * with the new value if the key is already in the symbol table.
     * If the value is {@code null}, this effectively deletes the key from the symbol table.
     * @param key the key
     * @param val the value
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
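    public void put(String key, Value val) {
        if (key == null) {
            throw new IllegalArgumentException("calls put() with null key");
        }
        boolean present = contains(key);
        if (!present && val != null) n++;        // a new key is being added
        else if (present && val == null) n--;    // a null value removes an existing key
        root = put(root, key, val, 0);
    }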

    private Node put(Node x, String key, Value val, int d) {
        char c = key.charAt(d);
        if (x == null) {
            x = new Node();
            x.c = c;
        }
        if      (c < x.c)               x.left  = put(x.left,  key, val, d);
        else if (c > x.c)               x.right = put(x.right, key, val, d);
        else if (d < key.length() - 1)  x.mid   = put(x.mid,   key, val, d+1);
        else                            x.val   = val;
        return x;
    }

    /**
     * Returns the string in the symbol table that is the longest prefix of {@code query},
     * or {@code null}, if no such string.
     * @param query the query string
     * @return the string in the symbol table that is the longest prefix of {@code query},
     *     or {@code null} if no such string
     * @throws IllegalArgumentException if {@code query} is {@code null}
     */
    public String longestPrefixOf(String query) {
        if (query == null) {
            throw new IllegalArgumentException("calls longestPrefixOf() with null argument");
        }
        if (query.length() == 0) return null;
        int length = 0;
        Node x = root;
        int i = 0;
        while (x != null && i < query.length()) {
            char c = query.charAt(i);
            if      (c < x.c) x = x.left;
            else if (c > x.c) x = x.right;
            else {
                i++;
                if (x.val != null) length = i;
                x = x.mid;
            }
        }
        if (length == 0) return null;   // no key in the table is a prefix of query
        return query.substring(0, length);
    }

    /**
     * Returns all keys in the symbol table as an {@code Iterable}.
     * To iterate over all of the keys in the symbol table named {@code st},
     * use the foreach notation: {@code for (Key key : st.keys())}.
     * @return all keys in the symbol table as an {@code Iterable}
     */
    public Iterable<String> keys() {
        Queue<String> queue = new Queue<String>();
        collect(root, new StringBuilder(), queue);
        return queue;
    }

    /**
     * Returns all of the keys in the set that start with {@code prefix}.
     * @param prefix the prefix
     * @return all of the keys in the set that start with {@code prefix},
     *     as an iterable
     * @throws IllegalArgumentException if {@code prefix} is {@code null}
     */
    public Iterable<String> keysWithPrefix(String prefix) {
        if (prefix == null) {
            throw new IllegalArgumentException("calls keysWithPrefix() with null argument");
        }
        Queue<String> queue = new Queue<String>();
        Node x = get(root, prefix, 0);
        if (x == null) return queue;
        if (x.val != null) queue.enqueue(prefix);
        collect(x.mid, new StringBuilder(prefix), queue);
        return queue;
    }

    // all keys in subtrie rooted at x with given prefix
    private void collect(Node x, StringBuilder prefix, Queue<String> queue) {
        if (x == null) return;
        collect(x.left,  prefix, queue);
        if (x.val != null) queue.enqueue(prefix.toString() + x.c);
        collect(x.mid,   prefix.append(x.c), queue);
        prefix.deleteCharAt(prefix.length() - 1);
        collect(x.right, prefix, queue);
    }


    /**
     * Returns all of the keys in the symbol table that match {@code pattern},
     * where . symbol is treated as a wildcard character.
     * @param pattern the pattern
     * @return all of the keys in the symbol table that match {@code pattern},
     *     as an iterable, where . is treated as a wildcard character.
     */
    public Iterable<String> keysThatMatch(String pattern) {
        Queue<String> queue = new Queue<String>();
        collect(root, new StringBuilder(), 0, pattern, queue);
        return queue;
    }

    private void collect(Node x, StringBuilder prefix, int i, String pattern, Queue<String> queue) {
        if (x == null) return;
        char c = pattern.charAt(i);
        if (c == '.' || c < x.c) collect(x.left, prefix, i, pattern, queue);
        if (c == '.' || c == x.c) {
            if (i == pattern.length() - 1 && x.val != null) queue.enqueue(prefix.toString() + x.c);
            if (i < pattern.length() - 1) {
                collect(x.mid, prefix.append(x.c), i+1, pattern, queue);
                prefix.deleteCharAt(prefix.length() - 1);
            }
        }
        if (c == '.' || c > x.c) collect(x.right, prefix, i, pattern, queue);
    }

    /**
     * Unit tests the {@code TST} data type.
     *
     * @param args the command-line arguments
     */
    public static void main(String[] args) {

        // build symbol table from standard input
        TST<Integer> st = new TST<Integer>();
        for (int i = 0; !StdIn.isEmpty(); i++) {
            String key = StdIn.readString();
            st.put(key, i);
        }

        // print results
        if (st.size() < 100) {
            StdOut.println("keys(\"\"):");
            for (String key : st.keys()) {
                StdOut.println(key + " " + st.get(key));
            }
            StdOut.println();
        }

        StdOut.println("longestPrefixOf(\"shellsort\"):");
        StdOut.println(st.longestPrefixOf("shellsort"));
        StdOut.println();

        StdOut.println("longestPrefixOf(\"shell\"):");
        StdOut.println(st.longestPrefixOf("shell"));
        StdOut.println();

        StdOut.println("keysWithPrefix(\"shor\"):");
        for (String s : st.keysWithPrefix("shor"))
            StdOut.println(s);
        StdOut.println();

        StdOut.println("keysThatMatch(\".he.l.\"):");
        for (String s : st.keysThatMatch(".he.l."))
            StdOut.println(s);
    }
}

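4.3 Performance analysis

A ternary search trie is a compact representation of an R-way trie: each node holds only 3 links.

  • Time complexity
    A search miss takes about ln N character comparisons on average.
  • Space complexity
    A ternary search trie contains between 3N and 3Nw links
    (N: number of keys, w: average key length).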
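
The link bounds can be checked empirically on a small key set. The sketch below builds a ternary search trie with the same insertion scheme as above and counts its nodes; the class `TstLinkCountDemo`, the `countNodes` helper, and the sample keys are assumptions chosen for illustration.

```java
// Sketch: count the nodes of a small TST and compare 3 * nodes with the bounds 3N and 3Nw.
class TstLinkCountDemo {
    static class Node {
        char c;
        Node left, mid, right;
        Integer val;
    }

    Node root;

    void put(String key, int val) { root = put(root, key, val, 0); }

    private Node put(Node x, String key, int val, int d) {
        char c = key.charAt(d);
        if (x == null) { x = new Node(); x.c = c; }
        if      (c < x.c)              x.left  = put(x.left,  key, val, d);
        else if (c > x.c)              x.right = put(x.right, key, val, d);
        else if (d < key.length() - 1) x.mid   = put(x.mid,   key, val, d + 1);
        else                           x.val   = val;
        return x;
    }

    int countNodes() { return count(root); }

    private static int count(Node x) {
        if (x == null) return 0;
        return 1 + count(x.left) + count(x.mid) + count(x.right);
    }

    public static void main(String[] args) {
        String[] keys = { "she", "sells", "sea", "shells", "by", "the", "shore" };
        TstLinkCountDemo t = new TstLinkCountDemo();
        int totalChars = 0;
        for (int i = 0; i < keys.length; i++) {
            t.put(keys[i], i);
            totalChars += keys[i].length();
        }
        int links = 3 * t.countNodes();                 // every node holds exactly 3 links
        System.out.println("3N    = " + 3 * keys.length);   // 21
        System.out.println("links = " + links);             // 57
        System.out.println("3Nw   = " + 3 * totalChars);    // 81 (N * w is the total number of characters)
    }
}
```

Each node corresponds to one distinct non-empty prefix of the keys (19 here), so shared prefixes are what keep the count below 3Nw.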

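V. Comparison of string search algorithms

When space is plentiful, R-way tries are the fastest: a search finishes within an essentially constant number of character comparisons, independent of the number of keys.
For large alphabets, however, an R-way trie usually needs more space than is available; a ternary search trie is then the better choice, and its number of character comparisons grows logarithmically with the number of keys.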

[Figure 5-1: Comparison of string search algorithms]
