BK树

原文链接:http://blog.jobbole.com/78811/

这是『超酷算法』系列的第一篇文章。基本上,任何一种算法我觉得都很酷,尤其是那些不那么明显简单的算法。

BK树或者称为Burkhard-Keller树,是一种基于树的数据结构,被设计于快速查找近似字符串匹配,比方说拼写检查器,或模糊查找,当搜索”aeek”时能返回”seek”和”peek”。为何BK-Trees这么酷,因为除了穷举搜索,没有其他显而易见的解决方法,并且它能以简单和优雅的方法大幅度提升搜索速度。

BK树在1973年由Burkhard和Keller第一次提出,论文在这《Some approaches to best match file searching》。这是网上唯一的ACM存档,需要订阅。更细节的内容,可以阅读这篇论文《Fast Approximate String Matching in a Dictionary》。

在定义BK树之前,我们需要预先定义一些操作。为了索引和搜索字典,我们需要一种比较字符串的方法。编辑距离( Levenshtein Distance)是一种标准的方法,它用来表示经过插入、删除和替换操作从一个字符串转换到另外一个字符串的最小操作步数。其它字符串函数也同样可接受(比如将调换作为原子操作),只要能满足以下一些条件。

现在我们观察下编辑距离:构造一个度量空间(Metric Space),该空间内任何关系满足以下三条基本条件:

  • d(x,y) = 0 <-> x = y (假如x与y的距离为0,则x=y)
  • d(x,y) = d(y,x) (x到y的距离等同于y到x的距离)
  • d(x,y) + d(y,z) >= d(x,z)

上述条件中的最后一条被叫做三角不等式(Triangle Inequality)。三角不等式表明x到z的路径不可能长于另一个中间点的任何路径(从x到y再到z)。看下三角形,你不可能从一点到另外一点的两侧再画出一条比它更短的边来。

编辑距离符合基于以上三条所构造的度量空间。请注意,有其它更为普遍的空间,比如欧几里得空间(Euclidian Space),编辑距离不是欧几里得的。既然我们了解了编辑距离(或者其它类似的字符串距离函数)所表达的度量的空间,再来看下Burkhard和Keller所观察到的关键结论。

假设现在我们有两个参数,query表示我们搜索的字符串,n表示字符串最大距离,我们可以拿任意字符串test来跟query进行比较。调用距离函数得到距离d,因为我们知道三角不等式是成立的,所以所有结果与test的距离最大为d+n,最小为d-n。

由此,BK树的构造就相当简单:每个节点有任意个子节点,每条边有个值表示编辑距离。所有子节点到父节点的边上标注n表示编辑距离恰好为n。比如,我们有棵树父节点是”book”和两个子节点”rook”和”nooks”,”book”到”rook”的边标号1,”book”到”nooks”的边上标号2。

从字典里构造好树后,取任意单词作为树的根节点。无论何时你想插入新单词时,计算该单词与根节点的编辑距离,并且查找数值为d(neweord, root)的边。递归得与各子节点进行比较,直到没有子节点,你就可以创建新的子节点并将新单词保存在那。比如,插入”boon”到刚才上述例子的树中,我们先检查根节点,查找d(“book”, “boon”) = 1的边,然后检查标号为1的边的子节点,得到单词”rook”。我们再计算距离d(“rook”, “boon”)=2,则将新单词插在”rook”之后,边标号为2。

在树中做查询,计算单词与根节点的编辑距离d,然后递归查找每个子节点标号为d-n到d+n(包含)的边。假如被检查的节点与搜索单词的距离d小于n,则返回该节点并继续查询。

BK树是多路查找树,并且是不规则的(但通常是平衡的)。试验表明,1个查询的搜索距离不会超过树的5-8%,并且2个错误查询的搜索距离不会超过树的17-25%,这可比检查每个节点改进了一大步啊!需要注意的是,如果要进行精确查找,也可以非常有效地通过简单地将n设置为0进行。

回顾这篇文章,写的有点长哈,似乎比我预期中的要复杂。希望你在阅读之后,也能感受到BK树的优雅和简单。



HDU 4323 bk树 编辑距离 原文链接:http://www.cnblogs.com/tangcong/archive/2012/09/10/2679081.htm

http://www.matrix67.com/blog/archives/333

http://www.cnblogs.com/tangcong/archive/2012/09/10/2679081.html

除了字符串匹配、查找回文串、查找重复子串等经典问题以外,日常生活中我们还会遇到其它一些怪异的字符串问题。比如,有时我们需要知道给定的两个字符串“有多像”,换句话说两个字符串的相似度是多少。1965年,俄国科学家Vladimir Levenshtein给字符串相似度做出了一个明确的定义叫做Levenshtein距离,我们通常叫它“编辑距离”。字符串A到B的编辑距离是指,只用插入、删除和替换三种操作,最少需要多少步可以把A变成B。例如,从FAME到GATE需要两步(两次替换),从GAME到ACM则需要三步(删除G和E再添加C)。Levenshtein给出了编辑距离的一般求法,就是大家都非常熟悉的经典动态规划问题。
    在自然语言处理中,这个概念非常重要,例如我们可以根据这个定义开发出一套半自动的校对系统:查找出一篇文章里所有不在字典里的单词,然后对于每个单词,列出字典里与它的Levenshtein距离小于某个数n的单词,让用户选择正确的那一个。n通常取到2或者3,或者更好地,取该单词长度的1/4等等。这个想法倒不错,但算法的效率成了新的难题:查字典好办,建一个Trie树即可;但怎样才能快速在字典里找出最相近的单词呢?这个问题难就难在,Levenshtein的定义可以是单词任意位置上的操作,似乎不遍历字典是不可能完成的。现在很多软件都有拼写检查的功能,提出更正建议的速度是很快的。它们到底是怎么做的呢?1973年,Burkhard和Keller提出的BK树有效地解决了这个问题。这个数据结构强就强在,它初步解决了一个看似不可能的问题,而其原理非常简单。

    首先,我们观察Levenshtein距离的性质。令d(x,y)表示字符串x到y的Levenshtein距离,那么显然:

1. d(x,y) = 0 当且仅当 x=y  (Levenshtein距离为0 <==> 字符串相等)
2. d(x,y) = d(y,x)     (从x变到y的最少步数就是从y变到x的最少步数)
3. d(x,y) + d(y,z) >= d(x,z)  (从x变到z所需的步数不会超过x先变成y再变成z的步数)

    最后这一个性质叫做三角形不等式。就好像一个三角形一样,两边之和必然大于第三边。给某个集合内的元素定义一个二元的“距离函数”,如果这个距离函数同时满足上面说的三个性质,我们就称它为“度量空间”。我们的三维空间就是一个典型的度量空间,它的距离函数就是点对的直线距离。度量空间还有很多,比如Manhattan距离,图论中的最短路,当然还有这里提到的Levenshtein距离。就好像并查集对所有等价关系都适用一样,BK树可以用于任何一个度量空间。

    建树的过程有些类似于Trie。首先我们随便找一个单词作为根(比如GAME)。以后插入一个单词时首先计算单词与根的Levenshtein距离:如果这个距离值是该节点处头一次出现,建立一个新的儿子节点;否则沿着对应的边递归下去。例如,我们插入单词FAME,它与GAME的距离为1,于是新建一个儿子,连一条标号为1的边;下一次插入GAIN,算得它与GAME的距离为2,于是放在编号为2的边下。再下次我们插入GATE,它与GAME距离为1,于是沿着那条编号为1的边下去,递归地插入到FAME所在子树;GATE与FAME的距离为2,于是把GATE放在FAME节点下,边的编号为2。
      
    查询操作异常方便。如果我们需要返回与错误单词距离不超过n的单词,这个错误单词与树根所对应的单词距离为d,那么接下来我们只需要递归地考虑编号在d-n到d+n范围内的边所连接的子树。由于n通常很小,因此每次与某个节点进行比较时都可以排除很多子树。
    举个例子,假如我们输入一个GAIE,程序发现它不在字典中。现在,我们想返回字典中所有与GAIE距离为1的单词。我们首先将GAIE与树根进行比较,得到的距离d=1。由于Levenshtein距离满足三角形不等式,因此现在所有离GAME距离超过2的单词全部可以排除了。比如,以AIM为根的子树到GAME的距离都是3,而GAME和GAIE之间的距离是1,那么AIM及其子树到GAIE的距离至少都是2。于是,现在程序只需要沿着标号范围在1-1到1+1里的边继续走下去。我们继续计算GAIE和FAME的距离,发现它为2,于是继续沿标号在1和3之间的边前进。遍历结束后回到GAME的第二个节点,发现GAIE和GAIN距离为1,输出GAIN并继续沿编号为1或2的边递归下去(那条编号为4的边连接的子树又被排除掉了)……
    实践表明,一次查询所遍历的节点不会超过所有节点的5%到8%,两次查询则一般不会17-25%,效率远远超过暴力枚举。适当进行缓存,减小Levenshtein距离常数n可以使算法效率更高。


[cpp]  view plain copy
  1. #include  
  2. #include  
  3. #include  
  4. #include  
  5. #include  
  6. #include  
  7. #include  
  8. #include  
  9. #include  
  10. #include  
  11. #include  
  12. #include  
  13. using namespace std;  
  14.   
  15. int dp[40][40];  
  16. char s1[100], s2[100], st[10010][30];  
  17. const int inf  = 0x7f7f7f7f;  
  18. //数据结构定义  
  19. struct node  
  20. {  
  21.   char word[30]; //当前结点值  
  22.   node *next[30];  
  23. }root;  
  24.   
  25. node p[100000];  
  26. int num, flag, vnum, fuck;  
  27. mapint>mp;  
  28.   
  29. int f[100000];  
  30.   
  31. void init( )  
  32. {  
  33.   forint i = 0; i < 40; i++)  
  34.        forint j = 0; j < 40; j++)  
  35.             dp[i][j] = inf;        
  36. }  
  37.   
  38. int diff( char *s1, char *s2)  
  39. {  
  40.   init();  
  41.   int x = strlen(s1+1);  
  42.   int y = strlen(s2+1);  
  43.     forint i = 0; i <= x; i++)  
  44.         dp[i][0] = i;  
  45.    forint j = 0; j <= y; j++)  
  46.         dp[0][j] = j;  
  47.    forint i = 1; i <= x; i++)  
  48.    {  
  49.         forint j = 1; j <= y; j++)  
  50.         {  
  51.             
  52.              dp[i][j] = min(min(dp[i-1][j]+1, dp[i][j-1]+1), dp[i-1][j-1]+ !(s1[i]==s2[j]) );  
  53.         }    
  54.     
  55.   }  
  56.   return dp[x][y];   
  57. }   
  58.   
  59. //建树  
  60. void insert(node *q, char *str)  
  61. {  
  62.   node *l = q;  
  63.   while( l )  
  64.   {  
  65.      int dis = diff( l->word, str);  
  66.      if( ! l->next[dis] )  
  67.      {  
  68.         l->next[dis] = &p[num++];  
  69.         strcpy(l->next[dis]->word + 1, str + 1);  
  70.         break;  
  71.      }  
  72.      l = l->next[dis];                 
  73.   }          
  74. }  
  75.   
  76. //查找与单词相差不大于d的单词   
  77. void sfind(node *q, char *str, int d)  
  78. {  
  79.   if( flag )   
  80.       return ;  
  81.   node *l = q;  
  82.   if( l == NULL )  
  83.       return;  
  84.   int dis = diff(str, l->word);  
  85.   if( dis <= d )  
  86.   {  
  87.     fuck++;  
  88.   }  
  89.   forint x = dis-d; x <= dis+d; x++)  
  90.   {    
  91.      if( x >= 0 && x <= 20 && l->next[x] )  
  92.          sfind(l->next[x], str, d);       
  93.   }  
  94.        
  95. }  
  96.   
  97.    
  98. int main( )  
  99. {  
  100.   int N, M, d, cnt, T, abc = 1;  
  101.   char str[1000];  
  102.   scanf("%d",&T);  
  103.   while( T-- )  
  104.   {  
  105.     scanf("%d%d",&N,&M);  
  106.     memset(p,0,sizeof(p));  
  107.     forint i = 0; i < 30; i++)  
  108.          root.next[i] = NULL;  
  109.     num = 0;  
  110.     int cnum = 1;  
  111.     strcpy(st[0] + 1, root.word+1);  
  112.     forint i = 1; i <= N; i++)  
  113.     {  
  114.        scanf("%s",st[i]+1);  
  115.        insert(&root, st[i]);  
  116.     }  
  117.     d = 1;  
  118.     printf("Case #%d:\n", abc++);  
  119.     forint i = 1; i <= M; i++)  
  120.     {  
  121.        vnum = 0;  
  122.        flag = 0;  
  123.        fuck = 0;  
  124.        scanf("%s%d",str+1, &d);  
  125.        sfind(&root, str, d);   
  126.        printf("%d\n", fuck);  
  127.     }  
  128.   }  
  129.   return 0;  
  130. }  

自己写的版本,比较容易理解


[cpp]  view plain copy
  1. #include   
  2. #include   
  3. #include   
  4. #include   
  5. #include   
  6. #include   
  7. #include   
  8. #include   
  9. #include   
  10. #include   
  11.   
  12. using namespace std;  
  13. #define MAXEDIT 15  
  14. class node {  
  15. public:  
  16.     string word;  
  17.     node *next[MAXEDIT];  
  18.   
  19.     node() {  
  20.          
  21.         memset(next, 0, sizeof(next));  
  22.     }  
  23. };  
  24.   
  25. string split(const string& str) {  
  26.     size_t pos = str.find(" ||| ");  
  27.     return str.substr(0, pos);  
  28. }  
  29.   
  30. bool isalpha(const string& str) {  
  31.     for (int i = 0; i < str.size(); ++i) {  
  32.         if (!(str[i]>='a' && str[i] <='z' || str[i]>='A' && str[i] <='Z' )) return false;  
  33.     }  
  34.     return true;  
  35. }  
  36.   
  37. int minTri(int a, int b, int c) {  
  38.     int rst = a;  
  39.     if (rst > b) rst = b;  
  40.     if (rst > c) rst = c;  
  41.   
  42.     return rst;  
  43. }  
  44.   
  45. int editDist(const string &str1, const string &str2) {  
  46.     vectorint> > mat(str1.size() + 1, vector<int>(str2.size() +1, 0));  
  47.     for (int i = 1; i < str1.size(); ++i) mat[i][0] = i;  
  48.     for (int i = 1; i < str2.size(); ++i) mat[0][i] = i;  
  49.   
  50.     for (int i = 1; i <= str1.size(); ++i) {  
  51.         for (int j = 1; j <= str2.size(); ++j) {  
  52.             int cost = 1;  
  53.             if (str1[i-1] == str2[j-1]) cost = 0;  
  54.   
  55.             mat[i][j] = minTri(mat[i-1][j-1]+cost, mat[i-1][j] + 1, mat[i][j-1] + 1);  
  56.         }  
  57.     }  
  58.   
  59.     return mat[str1.size()][str2.size()];  
  60. }  
  61.   
  62. void insert(node* head, const string& str) {  
  63.     node *tmp = head;  
  64.     while (tmp) {  
  65.         int dis = editDist(tmp->word, str);  
  66.         if (dis == 0 || dis >= MAXEDIT) return;  
  67.         if (tmp->next[dis]) tmp = tmp->next[dis];  
  68.         else {  
  69.             tmp->next[dis] = new node();  
  70.             tmp->next[dis]->word = str;  
  71.             break;  
  72.         }  
  73.     }  
  74.      
  75. }  
  76.   
  77. void buildKDTree(node *head, const vector& ls) {  
  78.     for (int i = 0; i < ls.size(); ++i) {  
  79.         insert(head, ls[i]);  
  80.     }  
  81. }  
  82.   
  83. void freeKDTree(node* head) {  
  84.     for (int i = 0; i < MAXEDIT; ++i) {  
  85.         if (head->next[i]) {  
  86.             freeKDTree(head->next[i]);  
  87.             delete head->next[i];  
  88.             head->next[i] = NULL;  
  89.         }  
  90.     }  
  91. }  
  92.   
  93. void findN(node *head, const string & str,vectorint> >& rst, int n) {  
  94.     int d = editDist(head->word, str);  
  95.     if (d <= n && d != 0) {  
  96.         rst.push_back(make_pair(head->word,d));  
  97.     }  
  98.     int minR = max(1, d - n);  
  99.     int maxR = min(MAXEDIT-1, d + n);  
  100.     for (int i = minR; i <= maxR; ++i) {  
  101.         if (head->next[i]) {  
  102.             findN(head->next[i], str, rst, n);  
  103.         }  
  104.     }  
  105. }  
  106.   
  107. bool Cmp(const pairint>& p1, const pairint> &p2) {  
  108.     return p1.second < p2.second;  
  109. }  
  110.   
  111. int main(int argc, char *argv[]) {  
  112.       
  113.       
  114.     if (argc != 3) {  
  115.         cout << "input output"<
  116.         return -1;  
  117.     }  
  118.     ifstream fin(argv[1]);  
  119.     ofstream fo(argv[2]);  
  120.       
  121.     string line;  
  122.       
  123.     set st;  
  124.     while(getline(fin, line)) {  
  125.         string word = split(line);  
  126.         if (isalpha(word) && word.size() > 1)  
  127.             st.insert(word);  
  128.     }  
  129.       
  130.     vector ls(st.size());  
  131.     set::iterator it = st.begin();  
  132.     int i = 0;  
  133.     for(; it != st.end(); ++it)  
  134.         ls[i++] = *it;  
  135.   
  136.     node head;  
  137.     head.word = ls[0];  
  138.     buildKDTree(&head, ls);  
  139.   
  140.     for (i = 0; i < ls.size();++i) {  
  141.         if ((i+1)%5000 ==0) cout << i+1<
  142.         vectorint> > rst;  
  143.         int dist = min((int)ls[i].size()/2, 3);  
  144.         findN(&head, ls[i], rst, dist);  
  145.         ostringstream ostr;  
  146.         ostr<"\t";  
  147.         sort(rst.begin(), rst.end(), Cmp);  
  148.         for (int j = 0; j < rst.size(); ++j) {  
  149.             ostr<" ";  
  150.         }  
  151.         fo<
  152.     }  
  153.   
  154.     freeKDTree(&head);  
  155.   
  156.     fin.close();  
  157.     fo.close();  
  158.     system("pause");  
  159.    return 0;  
  160. }  

你可能感兴趣的:(BK树)