双数组字典树的java实现

双数组字典树的算法思想这里就不在详述,有兴趣的可以自己谷歌一下。

废话少说,java代码如下:

 

 

/**
 *
 */
package com.kongfz.service.banned.check;

/**
 * 双数组字典树查找敏感词算法
 *
 * 读代码前,请先了解字典树和双数组字典树算法思想
 * @author Administrator
 *
 */
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class DoubleArrayTrie {
   
    /** 分词结束符 */
    private final char END_CHAR = '\0';
   
    /** 分配步长 */
    private final int DEFAULT_LEN = 1024;
   
    /** 基础位置数组 */
    private int base[] = new int[DEFAULT_LEN];
   
    /** 前一状态数组 */
    private int check[] = new int[DEFAULT_LEN];
   
    /** 词语结尾字数组 */
    private char tail[] = new char[DEFAULT_LEN];
   
    /** 开始位置 */
    int pos = 1;
   
    /** 字典字符和位置对应关系 */
    Map<Character, Integer> charMap = new HashMap<Character, Integer>();
   
    /** 字典字符列表 */
    ArrayList<Character> charList = new ArrayList<Character>();

    /**
     * 构造函数
     */
    public DoubleArrayTrie() {
        base[1] = 1;

        charMap.put(END_CHAR, 1);
        charList.add(END_CHAR);
        charList.add(END_CHAR);
        for (int i = 0; i < 26; ++i) {
            charMap.put((char) ('a' + i), charMap.size() + 1);
            charList.add((char) ('a' + i));
        }
    }

    /**
     * 扩充位置数组和状态数组
     */
    private void Extend_Array() {
        base = Arrays.copyOf(base, base.length * 2);
        check = Arrays.copyOf(check, check.length * 2);
    }

    /**
     * 扩充结尾字符数组
     */
    private void Extend_Tail() {
        tail = Arrays.copyOf(tail, tail.length * 2);
    }

    /**
     * 从字符关系map中获取字符位置
     * 不存在时添加
     * @param c
     * @return
     */
    private int getAndAddCharCode(char c) {
        if (!charMap.containsKey(c)) {
            charMap.put(c, charMap.size() + 1);
            charList.add(c);
        }
        return charMap.get(c);
    }
   
    /**
     * 从字符集map中取得指定字符位置
     * @param c
     * @return
     */
    private int getCharCode(char c) {
        if (!charMap.containsKey(c)) {
            return -1;
        }
        return charMap.get(c);
    }

    /**
     * 复制字符到词语结尾字数组
     * @param s
     * @param p
     * @return
     */
    private int copyToTail(String s, int p) {
        int _pos = pos;
        while (s.length() - p + 1 > tail.length - pos) {
            Extend_Tail();
        }
        for (int i = p; i < s.length(); ++i) {
            tail[_pos] = s.charAt(i);
            _pos++;
        }
        return _pos;
    }

    /**
     * 冲突时计算下一个空闲的位置
     * @param set
     * @return
     */
    private int x_check(Integer[] set) {
        for (int i = 1;; ++i) {
            boolean flag = true;
            for (int j = 0; j < set.length; ++j) {
                int cur_p = i + set[j];
                if (cur_p >= base.length)
                    Extend_Array();
                if (base[cur_p] != 0 || check[cur_p] != 0) {
                    flag = false;
                    break;
                }
            }
            if (flag)
                return i;
        }
    }

    /**
     * 取得所有同义词
     * @param p
     * @return
     */
    private ArrayList<Integer> getChildList(int p) {
        ArrayList<Integer> ret = new ArrayList<Integer>();
        for (int i = 1; i <= charMap.size(); ++i) {
            if (base[p] + i >= check.length)
                break;
            if (check[base[p] + i] == p) {
                ret.add(i);
            }
        }
        return ret;
    }

    /**
     * 判断结尾字数组中是否包含某字符
     * @param start
     * @param s2
     * @return
     */
    private boolean tailContainString(int start, String s2) {
        for (int i = 0; i < s2.length(); ++i) {
            if (s2.charAt(i) != tail[i + start])
                return false;
        }
        return true;
    }

    private boolean tailMatchString(int start, String s2) {
        s2 += END_CHAR;
        for (int i = 0; i < s2.length(); ++i) {
            if (s2.charAt(i) != tail[i + start])
                return false;
        }
        return true;
    }

    /**
     * 向字典中插入词
     * @param word
     * @throws Exception
     */
    public void insertWord(String word) throws Exception {
        word += END_CHAR;
        int pre_p = 1;
        int cur_p;
        for (int i = 0; i < word.length(); ++i) {
            // 获取状态位置
            cur_p = base[pre_p] + getAndAddCharCode(word.charAt(i));
            // 如果长度超过现有,拓展数组
            if (cur_p >= base.length){
                Extend_Array();
            }
            // 空闲状态
            if (base[cur_p] == 0 && check[cur_p] == 0) {
                base[cur_p] = -pos;
                check[cur_p] = pre_p;
                pos = copyToTail(word, i + 1);
                break;
            } else {
                // 已存在状态
                if (base[cur_p] > 0 && check[cur_p] == pre_p) {
                    pre_p = cur_p;
                    continue;
                } else {
                    // 冲突 1:遇到 Base[cur_p]小于0的,即遇到一个被压缩存到Tail中的字符串
                    if (base[cur_p] < 0 && check[cur_p] == pre_p) {
                        int head = -base[cur_p];
                        // 插入重复字符串
                        if (word.charAt(i + 1) == END_CHAR && tail[head] == END_CHAR) {
                            break;
                        }
                        // 公共字母的情况,因为上一个判断已经排除了结束符,所以一定是2个都不是结束符
                        if (tail[head] == word.charAt(i + 1)) {
                            int avail_base = x_check(new Integer[] { getAndAddCharCode(word.charAt(i + 1)) });
                            base[cur_p] = avail_base;
                            check[avail_base + getAndAddCharCode(word.charAt(i + 1))] = cur_p;
                            base[avail_base + getAndAddCharCode(word.charAt(i + 1))] = -(head + 1);
                            pre_p = cur_p;
                            continue;
                        } else {
                            // 2个字母不相同的情况,可能有一个为结束符
                            int avail_base;
                            avail_base = x_check(new Integer[] {
                                    getAndAddCharCode(word.charAt(i + 1)),
                                    getAndAddCharCode(tail[head]) });
                            base[cur_p] = avail_base;
                            check[avail_base + getAndAddCharCode(tail[head])] = cur_p;
                            check[avail_base + getAndAddCharCode(word.charAt(i + 1))] = cur_p;
                            // Tail 为END_FLAG 的情况
                            if (tail[head] == END_CHAR) {
                                base[avail_base + getAndAddCharCode(tail[head])] = 0;
                            } else {
                                base[avail_base + getAndAddCharCode(tail[head])] = -(head + 1);
                            }
                            if (word.charAt(i + 1) == END_CHAR) {
                                base[avail_base + getAndAddCharCode(word.charAt(i + 1))] = 0;
                            } else {
                                base[avail_base + getAndAddCharCode(word.charAt(i + 1))] = -pos;
                            }
                            pos = copyToTail(word, i + 2);
                            break;
                        }
                    } else {
                        // 冲突2:当前结点已经被占用,需要调整pre的base
                        if (check[cur_p] != pre_p) {
                            ArrayList<Integer> list1 = getChildList(pre_p);
                            int toBeAdjust;
                            ArrayList<Integer> list = null;
                            if (true) {
                                toBeAdjust = pre_p;
                                list = list1;
                            }
                            int origin_base = base[toBeAdjust];
                            list.add(getAndAddCharCode(word.charAt(i)));
                            int avail_base = x_check((Integer[]) list.toArray(new Integer[list.size()]));
                            list.remove(list.size() - 1);
                            base[toBeAdjust] = avail_base;
                            for (int j = 0; j < list.size(); ++j) {
                                // BUG
                                int tmp1 = origin_base + list.get(j);
                                int tmp2 = avail_base + list.get(j);
                                base[tmp2] = base[tmp1];
                                check[tmp2] = check[tmp1];
                                // 有后续
                                if (base[tmp1] > 0) {
                                    ArrayList<Integer> subsequence = getChildList(tmp1);
                                    for (int k = 0; k < subsequence.size(); ++k) {
                                        check[base[tmp1] + subsequence.get(k)] = tmp2;
                                    }
                                }
                                base[tmp1] = 0;
                                check[tmp1] = 0;
                            }
                            // 更新新的cur_p
                            cur_p = base[pre_p] + getAndAddCharCode(word.charAt(i));
                            if (word.charAt(i) == END_CHAR) {
                                base[cur_p] = 0;
                            } else {
                                base[cur_p] = -pos;
                            }
                            check[cur_p] = pre_p;
                            pos = copyToTail(word, i + 1);
                            break;
                        }
                    }
                }
            }
        }
    }

    /**
     * 查找词典中是否包含某个词语
     * @param word
     * @return
     */
    public boolean Exists(String word) {
        int pre_p = 1;
        int cur_p = 0;

        for (int i = 0; i < word.length(); ++i) {
            cur_p = base[pre_p] + getAndAddCharCode(word.charAt(i));
            if (check[cur_p] != pre_p)
                return false;
            if (base[cur_p] < 0) {
                if (tailMatchString(-base[cur_p], word.substring(i + 1)))
                    return true;
                return false;
            }
            pre_p = cur_p;
        }
        if (check[base[cur_p] + getAndAddCharCode(END_CHAR)] == cur_p)
            return true;
        return false;
    }

    // 内部函数,返回匹配单词的最靠后的Base index,
    class FindStruct {
        int p;
        String prefix = "";
    }

    /**
     * 从词典中匹配存在的词语
     * @param word
     * @return
     */
    private FindStruct Find(String word) {
        int pre_p = 1;
        int cur_p = 0;
        FindStruct fs = new FindStruct();
        for (int i = 0; i < word.length(); ++i) {
            // BUG
            fs.prefix += word.charAt(i);
            cur_p = base[pre_p] + getCharCode(word.charAt(i));
            //字典树中不包含此字符开头的词语
            if (check[cur_p] != pre_p) {
                fs = new FindStruct();
                pre_p = 1;
                cur_p = 0;
                continue;
            }
            if (base[cur_p] < 0) {
                if (tailContainString(-base[cur_p], "")){
                    fs.p = cur_p;
                    return fs;
                }
                pre_p = 1;
                cur_p = 0;
                fs = new FindStruct();
                continue;
            }
            pre_p = cur_p;
        }
        fs.p = cur_p;
        return fs;
    }

    /**
     * 取得指定词语的为前缀的所有词语
     * @param index
     * @return
     */
    public ArrayList<String> GetAllChildWord(int index) {
        ArrayList<String> result = new ArrayList<String>();
        if (base[index] == 0) {
            result.add("");
            return result;
        }
        if (base[index] < 0) {
            String r = "";
            for (int i = -base[index]; tail[i] != END_CHAR; ++i) {
                r += tail[i];
            }
            result.add(r);
            return result;
        }
        for (int i = 1; i <= charMap.size(); ++i) {
            if (check[base[index] + i] == index) {
                for (String s : GetAllChildWord(base[index] + i)) {
                    result.add(charList.get(i) + s);
                }
                // result.addAll(GetAllChildWord(Base[index]+i));
            }
        }
        return result;
    }

    public ArrayList<String> findBannedWord(String word) {
        ArrayList<String> result = new ArrayList<String>();
//        String prefix = "";
        FindStruct fs = Find(word);
        int p = fs.p;
        if (p == -1)
            return result;
        if (base[p] < 0) {
            String r = "";
            for (int i = -base[p]; tail[i] != END_CHAR; ++i) {
                r += tail[i];
            }
            result.add(fs.prefix + r);
            return result;
        }
//
//        if (Base[p] > 0) {
//            ArrayList<String> r = GetAllChildWord(p);
//            for (int i = 0; i < r.size(); ++i) {
//                r.set(i, fs.prefix + r.get(i));
//            }
//            return r;
//        }
        result.add(fs.prefix);
        return result;
    }
   
    /**
     * 删除敏感词
     * @param word
     * @return
     */
    public boolean delBannedWord(String word){
        if(!this.Exists(word)){
            return false;
        }
        int pre_p = 1;
        int cur_p = 0;
        int start = 0;
        findToDel(word, pre_p, cur_p, start);
        return true;
    }
   
    /**
     * 删除敏感词
     * @param word
     * @param pre_p
     * @param cur_p
     * @param start
     * @return
     */
    private boolean findToDel(String word, int pre_p, int cur_p, int start) {
        char key = word.charAt(start);
        cur_p = base[pre_p] + getCharCode(key);
        if (base[cur_p] < 0) {
            if (tailContainString(-base[cur_p], "")){
                for (int i = -base[cur_p]; tail[i] != END_CHAR; ++i) {
                    tail[i] = END_CHAR;
                }
                base[cur_p]=0;
                check[cur_p]=0;
                return true;
            }
            return false;
        }
        pre_p = cur_p;
        findToDel(word, pre_p, cur_p, start+1);
        return true;
    }


    public static void main(String[] args) throws Exception {
        long start = System.currentTimeMillis();
        DoubleArrayTrie dat = new DoubleArrayTrie();
        //加载词库
//        InputStreamReader isr = new InputStreamReader(new FileInputStream(
//                "E:\\workspace\\MyProject\\src\\test\\segment\\dict.txt"),
//                "UTF-8");
//        BufferedReader br = new BufferedReader(isr);
//        String readLine = br.readLine();
//        while (readLine != null) {
//            dat.Insert(readLine);
//            readLine = br.readLine();
//        }
//        isr.close();
//        br.close();
//        System.out.println("The init total time is "
//                + ((System.currentTimeMillis() - start)) + "ms");

//        dat.insertWord("学生本");
        dat.insertWord("学校");
        dat.insertWord("学习");
        dat.insertWord("调查");
        dat.delBannedWord("调查");
        dat.insertWord("调查");
//        dat.insertWord("本地");
        dat.insertWord("北京");
        dat.delBannedWord("学校");
        dat.insertWord("学生");
//        dat.insertWord("学校");
       
       
        System.out.println("The update total time is "
                + ((System.currentTimeMillis() - start)) + "ms");
       
        System.out.println(dat.base.length);
        System.out.println(dat.tail.length);
        System.out.println("The init total time is "
                + ((System.currentTimeMillis() - start)) + "ms");

        String resStr = "圣斗士星 矢在月 宫败在了嫦 娥和玉 兔手下,因为星矢不是本地人,也是个不起眼的角色,goodbye."
                + "京华时报讯(记者 怀若谷 实习记者 常鑫)昨天,本报报道河南省潢川县王香铺村村民蔡先生家林地被烧一事,"
                + "当事双方对起火原因说法不一。昨晚,潢川县森林公安派出所称,起火林地类型为未成林造林地,不予刑事立案。"
                + "潢川县森林公安派出所李警官对京华时报记者表示,他们聘请淮滨县林业局工程师重新对起火地点进行勘测,"
                + "“总过火面积为21991.9平方米,其中有2240平方米范围内仅有零星树木痕迹。该林地类型为未成林造林地。”"
                + "李警官称,未成林造林地起火不牵扯刑事责任,因此不予刑事立案,但具体起火原因仍在调查。";
        System.out.println(dat.findBannedWord(resStr));
        String findStr = "学生是本地人学习";
//        System.out.println(dat.Exists(resStr));
        //检测敏感词,返回第一个存在的敏感词
        System.out.println(findStr + ":::::::" + dat.findBannedWord(findStr));
//       
        findStr = "学校也在本地";
//        System.out.println(dat.Exists(resStr));
        System.out.println(findStr + ":::::::" + dat.findBannedWord(findStr));
//       
        findStr = "本地的学生本";
//        System.out.println(dat.Exists(resStr));
        System.out.println(findStr + ":::::::" + dat.findBannedWord(findStr));
       
        findStr = "北京市的学生";
//        System.out.println(dat.Exists(resStr));
        System.out.println(findStr + ":::::::" + dat.findBannedWord(findStr));
//
//        findStr = "具体的事情";
//        System.out.println(dat.Exists(resStr));
//        System.out.println(findStr + ":::::::" + dat.FindAllWords(findStr));
       
        System.out.println("The total time is "
                + ((System.currentTimeMillis() - start)) + "ms");

    }
}
 

你可能感兴趣的:(java,双数组字典树)