scala 实现trie树匹配

近段时间需要使用trie树来加速like操作,在网上找了一圈发现没有可用的scala实现的trie树于是自己改了一版。

该版本根据java版的trie树改编而来,java版地址:https://www.jianshu.com/p/b5b8f00d0e55

实现代码如下:

import scala.collection.mutable.{HashMap, ListBuffer}

class TrieTree {
  private var root:TrieNode = null

  def this(words: List[String]) {
    this()
    root = new TrieNode
    try words.foreach(addWord)
    catch {
      case e: Exception =>
        e.printStackTrace()
    }
  }

  /**
   * 向树中添加单词
   * @param word
   */
  def addWord(word: String): Unit = {
    if (word == null || word.length == 0) return
    // 每次调用addWord都重新拿到全局根节点对象
    var current = root
    for(i <- 0 until word.length){
      val code = word.charAt(i)
      current = current.add(code)
    }
    current.end = true
  }

  /**
   * 判断文本中是否有树中可匹配的词
   * @param text 要参与匹配的文本
   * @return
   */
  def isContains(text: String): Boolean = {
    if (text == null || text.length == 0) return false
    // 获得前缀树
    var current = root
    // 从词的首位开始遍历
    var index = 0
    while (index < text.length){
      // 如果在当前层找到当前字母,继续往下一层找
      if (current.child.getOrElse(text.charAt(index),null) != null)
        current = current.child.getOrElse(text.charAt(index),null)
      else
      { // 如果在当前这一层找不到字符子节点,直接切到新的root该子节点下重新找
        // 如果root下也没有该字母,继续返回root给下一个字母调用防止空指针
        current = if (root.child.getOrElse(text.charAt(index),null) == null) root
        else root.child.getOrElse(text.charAt(index),null)
      }
      // 判断是否存在的依据: 当前查找返回的节点对象是否是end标志
      if (current.end) return true
      index += 1
    }
    false
  }

  /**
   * 获得匹配到的第一个词
   * @param text 要进行匹配的文本
   * @return
   */
  def getContainsItemOne(text: String): String = {
    if (text == null || text.length == 0) return null
    var current = root
    var index = 0
    var startIndex = 0
    var endIndex = 0
    while (index < text.length){
      if (current.child.getOrElse(text.charAt(index),null) != null) current = current.child.getOrElse(text.charAt(index),null)
      else { // startIndexstartIndex在else条件中更新
        // 有两种情况,如果在根节点都找不到当前字则从index+1开始,如果根节点存在该字,从index开始
        if (root.child.getOrElse(text.charAt(index),null) == null) {
          current = root
          startIndex = index + 1
        }
        else {
          current = root.child.getOrElse(text.charAt(index),null)
          startIndex = index
        }
      }
      if (current.end) {
        endIndex = index
        return text.substring(startIndex, endIndex + 1)
      }
      index += 1
    }
    null
  }

  /**
   * 获得可以配置到的所有词
   * @param text 要匹配的文本
   * @return
   */
  def getContainsItemAll(text: String): List[String] = {
    val res = ListBuffer[String]()
    if (text == null || text.length == 0) return res.toList
    var current = root
    var index = 0
    var startIndex = 0
    var endIndex = 0
    while (index < text.length) {
      if (current.child.getOrElse(text.charAt(index),null) != null) current = current.child.getOrElse(text.charAt(index),null)
      else { // startIndexstartIndex在else条件中更新
        // 有两种情况,如果在根节点都找不到当前字则从index+1开始,如果根节点存在该字,从index开始
        if (root.child.getOrElse(text.charAt(index),null) == null) {
          current = root
          startIndex = index + 1
        }
        else {
          current = root.child.getOrElse(text.charAt(index),null)
          startIndex = index
        }
      }
      if (current.end) {
        endIndex = index
        res += text.substring(startIndex, endIndex + 1)
        if (current.isLeaf) { // 重置为root
          current = root
          // 重置startIndex
          startIndex = endIndex + 1
        }
      }
      index += 1
    }
    res.toList
  }

  /**
   * 一个节点对象
   * value: 当前节点存储的字母
   * child: 当前节点的子节点信息 字母 -> 节点对象
   * end: 是否是整个词的结尾
   */
   private class TrieNode{
    var value = 0
    var child: HashMap[Char, TrieNode] = null
    var end = false

    def add(newChar: Char):TrieNode = {
      if (child == null) this.child = new HashMap[Char, TrieNode]
      // 找到对应字符的字典树
      var t = child.getOrElse(newChar,null)
      // 在map中查找是否已经存在字母
      if (t == null) { // 不存在则新建一个节点对象
        t = new TrieNode
        // 给节点对象命名为该字母
        t.value = newChar
        child.put(newChar, t)
      }
      // 返回下一个节点
      t
    }

    //判断本节点是否为叶子节点
    def isLeaf: Boolean = child == null
  }
}

object TrieTree{
  def apply(words: List[String]): TrieTree = {
    new TrieTree(words)
  }

  def main(args: Array[String]): Unit = {
    val trieTree = TrieTree(List("区块链", "智慧城市", "央行", "是吗", "区块"))
    val res0 = trieTree.isContains("吃饱了")
    val res1 = trieTree.isContains("吃饱了是吗")
    val res2 = trieTree.getContainsItemOne("我是央行区块链之王")
    val res3 = trieTree.getContainsItemAll("我是央行区块链之王")
    println(res0) //false
    println(res1) //true
    println(res2) //央行
    println(res3) //List(央行, 区块, 区块链)


  }
}

你可能感兴趣的:(scala 实现trie树匹配)