#ifndef TRIENODE_H
#define TRIENODE_H

#include <cstddef>
#include <cstring>
#include <string>

using namespace std;

// Number of child slots per node: slot 0 is reserved for the end-of-word
// marker '\0', slots 1..26 hold 'a'..'z' (see TrieNode::digit).
static const int MAX = 27;
// Character that maps to slot 1.
static const char BASE = 'a';

// A node of the trie used by class Trie.
//
// A node whose `word` is non-NULL is a leaf that owns a heap copy of one
// stored word; a node whose `word` is NULL is a branch whose children are
// indexed by TrieNode::digit.
struct TrieNode {
    // Creates a node. Pass NULL for a branch node, or a word to create a leaf
    // that keeps its own copy of that word.
    explicit TrieNode(const char *word);
    // Releases the copied word. Child nodes are NOT freed here.
    ~TrieNode();

    TrieNode *next[MAX];
    char *word;

    // Maps word[w] to a child slot: 0 for '\0', otherwise word[w] - BASE + 1.
    static int digit(char *word, int w);

private:
    // Non-copyable: `word` is an owning raw pointer, so a member-wise copy
    // would make two nodes delete[] the same buffer. Declared, never defined.
    TrieNode(const TrieNode &);
    TrieNode &operator=(const TrieNode &);
};

#endif // TRIENODE_H
#include "trienode.h"

#include <cassert>

// Builds a node. A NULL `word` yields a branch node; otherwise the node
// becomes a leaf holding a private heap copy of `word`. All children start NULL.
TrieNode::TrieNode(const char *word)
{
    for (int i = 0; i < MAX; i++) {
        next[i] = NULL;
    }
    this->word = NULL;
    if (word != NULL) {
        // size_t, not int: strlen's result must not be narrowed.
        size_t len = strlen(word);
        this->word = new char[len + 1];
        memcpy(this->word, word, len + 1); // includes the terminator
    }
}

// Frees the copied word (delete[] on NULL is a no-op). Children are owned and
// freed by the trie, not by the node.
TrieNode::~TrieNode()
{
    delete[] word;
}

// Returns the child slot for character word[w].
//
// Slot 0 is reserved for the terminator '\0'; this is what lets a word that
// is a prefix of another word still get its own leaf. Every other character
// maps to word[w] - BASE + 1.
//
// Precondition: word[w] is '\0' or a character in BASE .. BASE + MAX - 2
// ('a'..'z'); anything else would index TrieNode::next out of bounds.
int TrieNode::digit(char *word, int w)
{
    if (word[w] == '\0') {
        return 0;
    }
    int slot = word[w] - BASE + 1;
    assert(slot >= 1 && slot < MAX && "character outside the trie alphabet");
    return slot;
}
#ifndef TRIE_H #define TRIE_H #include "trienode.h" #include <iostream> #include <vector> using namespace std; class Trie { public: Trie(); ~Trie(); // 插入一个新的单词 void insert(char *word); // 打印字典树 void printTrie(); // 查看是否包含某个单词 bool contains(char *word); // 查找某个单词,返回这个单词指针,好像没什么用 const char* find(char *word); // 匹配单词前缀,获得所有匹配结果 int match(char *word, vector<char *> &ret); // 删除一个单词,静默吞噬错误异常 void remove(char *word); // 清空字典树 void clear(); // 检查字典树是否为空 bool empty(); // 获取匹配前缀的前n个匹配项 int match(char *word, int n, vector<char *> &ret); private: static TrieNode* insertR(TrieNode *root, char *word, int w); static TrieNode* splitR(TrieNode *p, TrieNode *q, int w); static void printTrieR(TrieNode *root); static TrieNode* findR(TrieNode *root, char *word, int w); static int matchR(TrieNode *root, char *word, int w, vector<char *> &ret); static int getWords(TrieNode *root, vector<char *> &ret); static TrieNode* matchPrefixR(TrieNode *root, char *word, int w); static bool removeR(TrieNode *root, char *word, int w); static void clearR(TrieNode *root); static int getWords(TrieNode *root, int n, vector<char *> &ret); static int matchR(TrieNode *root, char *word, int w, int n, vector<char *> &ret); private: TrieNode *root; }; #endif // TRIE_H
#include "trie.h"

#include <cstring>

namespace {

// Returns true when every character of `word` maps to a real child slot,
// i.e. lies in BASE .. BASE + MAX - 2 ('a'..'z' with the values in
// trienode.h). Any other character would make TrieNode::digit index
// TrieNode::next out of bounds ('`' would even collide with the '\0' slot),
// so the public entry points refuse such words instead of corrupting memory.
bool isStorableWord(const char *word)
{
    for (const char *c = word; *c != '\0'; ++c) {
        int slot = *c - BASE + 1;
        if (slot < 1 || slot >= MAX) {
            return false;
        }
    }
    return true;
}

} // namespace

// ---------------------------------------------------------------------------
// Structure shared by all helpers below:
//  * A node with word != NULL is a leaf holding one complete word; leaves
//    never get children (inserting at a leaf splits it instead).
//  * A node with word == NULL is a branch. A branch at depth w routes on
//    word[w] via TrieNode::digit. Slot 0 ('\0') holds the leaf of a word that
//    ends exactly there, which is how a word that is a prefix of another
//    word is stored; next[0] is therefore always a leaf or NULL.
//  * A leaf may sit above the end of its word (its tail is not spelled out by
//    the path), so a leaf must be confirmed by comparing the whole word.
// ---------------------------------------------------------------------------

Trie::Trie()
{
    root = NULL;
}

Trie::~Trie()
{
    clear();
}

// Inserts `word`. NULL, duplicate and unsupported words are ignored.
void Trie::insert(char *word)
{
    if (word == NULL || !isStorableWord(word)) {
        return;
    }
    root = insertR(root, word, 0);
}

// Prints every stored word, one per line, in slot order.
void Trie::printTrie()
{
    printTrieR(root);
}

bool Trie::contains(char *word)
{
    return find(word) != NULL;
}

// Returns the trie's own copy of `word`, or NULL if it is not stored.
const char *Trie::find(char *word)
{
    if (word == NULL || !isStorableWord(word)) {
        return NULL;
    }
    TrieNode *leaf = findR(root, word, 0);
    return (leaf == NULL) ? NULL : leaf->word;
}

// Appends every stored word starting with the prefix `word` to `ret` and
// returns how many were appended.
int Trie::match(char *word, vector<char *> &ret)
{
    if (word == NULL || !isStorableWord(word)) {
        return 0;
    }
    return matchR(root, word, 0, ret);
}

// Removes `word` if present; absent, NULL or unsupported words are ignored.
void Trie::remove(char *word)
{
    if (word == NULL || !isStorableWord(word)) {
        return;
    }
    // removeR returns true when it deleted the node it was handed. For the
    // root that means the trie is now empty, so the pointer must not dangle.
    if (removeR(root, word, 0)) {
        root = NULL;
    }
}

void Trie::clear()
{
    clearR(root);
    root = NULL;
}

bool Trie::empty()
{
    return root == NULL;
}

// Like match(), but appends at most the first n matches.
int Trie::match(char *word, int n, vector<char *> &ret)
{
    if (word == NULL || !isStorableWord(word)) {
        return 0;
    }
    return matchR(root, word, 0, n, ret);
}

// Inserts `word` into the subtree `root` (w characters already consumed) and
// returns the subtree's new root.
TrieNode *Trie::insertR(TrieNode *root, char *word, int w)
{
    if (root == NULL) {
        return new TrieNode(word);
    }
    if (root->word != NULL) {
        // Leaf. Re-inserting a stored word is a no-op. This must be checked
        // here: a duplicate of a prefix word arrives through slot 0, and
        // splitR would then read past both terminators.
        if (strcmp(root->word, word) == 0) {
            return root;
        }
        return splitR(root, new TrieNode(word), w);
    }
    // Branch: descend along word[w].
    int idx = TrieNode::digit(word, w);
    root->next[idx] = insertR(root->next[idx], word, w + 1);
    return root;
}

// Builds the subtree that separates leaf p from the new leaf q, starting at
// depth w, by adding branch nodes until their words differ.
TrieNode *Trie::splitR(TrieNode *p, TrieNode *q, int w)
{
    // Identical words: keep the existing leaf and drop the new one.
    if (p->word[w] == '\0' && q->word[w] == '\0') {
        delete q;
        return p;
    }
    TrieNode *t = new TrieNode(NULL);
    int pIdx = TrieNode::digit(p->word, w);
    int qIdx = TrieNode::digit(q->word, w);
    if (pIdx != qIdx) {
        // The words diverge at this character.
        t->next[pIdx] = p;
        t->next[qIdx] = q;
    } else {
        // Same character: keep splitting one level deeper.
        t->next[pIdx] = splitR(p, q, w + 1);
    }
    return t;
}

void Trie::printTrieR(TrieNode *root)
{
    if (root == NULL) {
        return;
    }
    if (root->word != NULL) {
        cout << root->word << endl;
    }
    for (int i = 0; i < MAX; i++) {
        printTrieR(root->next[i]);
    }
}

// Returns the leaf storing exactly `word` in the subtree `root` (w characters
// already consumed), or NULL if it is not there.
TrieNode *Trie::findR(TrieNode *root, char *word, int w)
{
    if (root == NULL) {
        return NULL;
    }
    if (root->word != NULL) {
        // Leaf: compare whole words, since the path may not cover the tail.
        return (strcmp(root->word, word) == 0) ? root : NULL;
    }
    if (word[w] == '\0') {
        // The word ends at this branch. Do not give up here: a word that is a
        // prefix of another stored word lives in the slot-0 leaf.
        TrieNode *end = root->next[0];
        return (end != NULL && end->word != NULL && strcmp(end->word, word) == 0) ? end : NULL;
    }
    return findR(root->next[TrieNode::digit(word, w)], word, w + 1);
}

int Trie::matchR(TrieNode *root, char *word, int w, vector<char *> &ret)
{
    // Locate the subtree matching the prefix, then collect its words.
    TrieNode *subTree = matchPrefixR(root, word, w);
    return (subTree == NULL) ? 0 : getWords(subTree, ret);
}

// Appends every word stored under `root` to `ret`; returns how many.
int Trie::getWords(TrieNode *root, vector<char *> &ret)
{
    if (root == NULL) {
        return 0;
    }
    if (root->word != NULL) {
        ret.push_back(root->word);
        return 1;
    }
    int count = 0;
    for (int i = 0; i < MAX; i++) {
        count += getWords(root->next[i], ret);
    }
    return count;
}

// Returns a node whose subtree holds exactly the stored words that start
// with the prefix `word` (w characters already consumed), or NULL if none do.
TrieNode *Trie::matchPrefixR(TrieNode *root, char *word, int w)
{
    if (root == NULL) {
        return NULL;
    }
    if (root->word != NULL) {
        // Leaf reached before the prefix was fully consumed along the path:
        // it matches only if its word really starts with the whole prefix.
        return (strncmp(root->word, word, strlen(word)) == 0) ? root : NULL;
    }
    if (word[w] == '\0') {
        // Prefix fully consumed at a branch: every word below shares it.
        return root;
    }
    return matchPrefixR(root->next[TrieNode::digit(word, w)], word, w + 1);
}

// Removes `word` from the subtree `root` (w characters already consumed).
// Returns true when `root` itself was deleted, so the caller must clear its
// pointer to it; returns false when the word is absent or `root` survives.
bool Trie::removeR(TrieNode *root, char *word, int w)
{
    if (root == NULL) {
        return false;
    }
    if (root->word != NULL) {
        if (strcmp(root->word, word) != 0) {
            return false;
        }
        delete root;
        return true;
    }
    // Branch. If the word ends here idx is 0, and next[0] is a leaf or NULL,
    // so the recursion never looks past the terminator.
    int idx = TrieNode::digit(word, w);
    if (!removeR(root->next[idx], word, w + 1)) {
        return false;
    }
    root->next[idx] = NULL;
    for (int i = 0; i < MAX; i++) {
        if (root->next[i] != NULL) {
            return false;
        }
    }
    // Branch has no children left: drop it too.
    delete root;
    return true;
}

// Frees every node of the subtree `root`.
void Trie::clearR(TrieNode *root)
{
    if (root == NULL) {
        return;
    }
    for (int i = 0; i < MAX; i++) {
        if (root->next[i] != NULL) {
            clearR(root->next[i]);
        }
    }
    delete root;
}

// Appends at most n words stored under `root` to `ret`; returns how many.
int Trie::getWords(TrieNode *root, int n, vector<char *> &ret)
{
    if (root == NULL || n <= 0) {
        return 0;
    }
    if (root->word != NULL) {
        ret.push_back(root->word);
        return 1;
    }
    int count = 0;
    for (int i = 0; i < MAX && n > 0; i++) {
        int nWords = getWords(root->next[i], n, ret);
        n -= nWords;
        count += nWords;
    }
    return count;
}

int Trie::matchR(TrieNode *root, char *word, int w, int n, vector<char *> &ret)
{
    TrieNode *subTree = matchPrefixR(root, word, w);
    return (subTree == NULL) ? 0 : getWords(subTree, n, ret);
}