poj 3376 Finding Palindromes(扩展kmp+trie)

 

poj 3376 Finding Palindromes(扩展kmp+trie)

分类: 字符串   89人阅读  评论(0)  收藏  举报
扩展kmp

poj 3376 Finding Palindromes(扩展kmp+trie)

题意:给出n个字符串,问这n个字符串两两链接(一共有n^2中连接方法),组成的所有的字符串中,有多少个回文串。

好题!!!

解题思路:对于两个串连接是否是回文串,我们应该怎样去判断了?假如我们把其中一个翻转,若此时,短的那个串是长的那个的前缀,而长的那个串后面剩余的后缀恰好是个回文串,那这两个串连起来就是个回文串了。比如abc和abacba连接,我们把后一个串翻转,得到abcaba,abc为其前缀,而aba是个回文串,那么连起来就是个回文串了(这个规律找到就好办了)。那我们就把所有的串插入到trie中,然后再用所有的反串去匹配就行了。匹配的过程中,走到任意一个节点,而这个节点有可能是若干个串的结尾,那么此时我们就要判反串匹配位置下面剩余的部分是否回文。如果是的,ans就加上以这个节点为结尾的原串的个数(这个插入的时候就可以统计进去了)。如果走完了,还没走到叶子节点,那么就要看走到的节点下的子树(其实是以前面走过的路径为前缀的字符串剩下的一些后缀)有多少是回文的了(这个先预处理所有的串的后缀有哪些是回文的,然后在插入的时候统计到节点上)。剩下来一个问题就是如何在线性的时间内(或许o(nlogn)也可以吧,但我们有线性的算法,岂不更好?),这里我只是说,用扩展kmp能很合适。具体如何实现,留个小思考给大家(很简单的啦)。。

[cpp]  view plain copy print ?
  1. #include<stdio.h>  
  2. #include<algorithm>  
  3. #include<string.h>  
  4. #define ll __int64  
  5. #include<vector>  
  6. using namespace std ;  
  7.   
  8. const int maxn = 2222222 ;  
  9. char vec[maxn] ;  
  10. int g[maxn] , nxt[maxn] ;  
  11. bool li[maxn] ;  
  12. int ok[maxn] , p[maxn] ;  
  13. void get_p (const char *T){  
  14.      int len=strlen(T),a=0;  
  15.      int i , k ;  
  16.      p[0]=len;  
  17.      while(a<len-1 && T[a]==T[a+1]) a++;  
  18.      p[1]=a;  
  19.      a=1;  
  20.      for( k=2;k<len;k++){  
  21.          int fuck=a+p[a]-1,L=p[k-a];  
  22.          if( (k-1)+L >= fuck){  
  23.              int j = (fuck-k+1)>0 ? (fuck-k+1) : 0;  
  24.              while(k+j<len && T[k+j]==T[j]) j++;  
  25.              p[k]=j;  
  26.              a=k;  
  27.          }  
  28.          else p[k]=L;  
  29.      }  
  30. }  
  31.   
  32. void match ( char *s , char *s1 ) {  
  33.     int len = strlen ( s ) , len1 = strlen ( s1 ) ;  
  34.     int i = 0 , k , j = 0 , a ;  
  35.     while ( i < len && j < len1 && s[i] == s1[j] ) i ++ , j ++ ;  
  36.     ok[0] = j ;  
  37.     a = 0 ;  
  38.     for ( k = 1 ; k < len ; k ++ ) {  
  39.         int fuck = a + ok[a] - 1 , l = p[k-a] ;  
  40.         if ( k + l - 1 >= fuck ) {  
  41.             int j = ( fuck - k + 1 ) > 0 ? ( fuck - k + 1 ) : 0 ;  
  42.             while ( k + j < len && j < len1 && s[k+j] == s1[j] ) j ++ ;  
  43.             ok[k] = j ;  
  44.             a = k ;  
  45.         }  
  46.         else ok[k] = l ;  
  47.     }  
  48. }  
  49.   
  50. int tot = 0 , c[26][maxn] , cnt[maxn] , val[maxn] ;  
  51.   
  52. int new_node () {  
  53.     int i ;  
  54.     for ( i = 0 ; i < 26 ; i ++ ) c[i][tot] = 0 ;  
  55.     cnt[tot] = val[tot] = 0 ;  
  56.     return tot ++ ;  
  57. }  
  58.   
  59. void insert ( char *s ) {  
  60.     int len = strlen ( s ) , i , now = 0 ;  
  61.     for ( i = 0 ; i < len ; i ++ ) {  
  62.         int k = s[i] - 'a' ;  
  63.         if ( !c[k][now] ) c[k][now] = new_node () ;  
  64.         now = c[k][now] ;  
  65.         if ( i + 1 < len && ok[i+1] == len - i - 1 ) {  
  66.             cnt[now] ++ ;  
  67.         }  
  68.     }  
  69.     cnt[now] ++ ;  
  70.     val[now] ++ ;  
  71. }  
  72.   
  73. ll ans = 0 ;  
  74.   
  75. void cal ( int len ) {  
  76.     int j , i , now = 0 ;  
  77.     li[len+1] = 1 ;  
  78. //  printf ( "len = %d\n" , len ) ;  
  79. //  for ( i = 1 ; i <= len ; i ++ ) printf ( "%d " , nxt[i] ) ; puts ( "" ) ;  
  80.     for ( j = 1 ; j <= len ; j ++ ) {  
  81.     //  printf ( "j = %d , ans = %I64d\n" , j , ans ) ;  
  82.         if ( li[j] ) now = 0 ;  
  83.         int k = vec[j] - 'a' ;  
  84.         if ( !c[k][now] ) {  
  85.             now = 0 ;  
  86.     //      printf ( "nxt[%d] = %d\n" , j , nxt[j] ) ;  
  87.             j = nxt[j] - 1 ;  
  88.             continue ;  
  89.         }  
  90.         now = c[k][now] ;  
  91.         if ( !li[j+1] && g[j+1] ) ans += (ll) val[now] ;  
  92. //      printf ( "j = %d , now = %d\n" , j , now ) ;  
  93.         if ( li[j+1] ) {  
  94.     //      if ( j == 10 ) printf ( "cnt[%d] = %d\n" , now , cnt[now] ) ;  
  95.             ans += (ll) cnt[now] ;  
  96.             now = 0 ;  
  97.         }  
  98.     }  
  99. }  
  100.   
  101. char s1[maxn] , s[maxn] ;  
  102.   
  103. int main () {  
  104.     int n , i , j , k ;  
  105.     while ( scanf ( "%d" , &n ) != EOF ) {  
  106.         tot = 0 ;  
  107.         new_node () ;  
  108.         int t = 0 ;  
  109.         for ( i = 1 ; i <= n ; i ++ ) {  
  110.             scanf ( "%d%s" , &j , s ) ;  
  111.             strcpy ( s1 , s ) ;  
  112.             int len = strlen ( s ) ;  
  113.             reverse ( s1 , s1 + len ) ;  
  114.             get_p ( s1 ) ;  
  115.             match ( s , s1 ) ;  
  116.             insert ( s ) ;  
  117.             get_p ( s ) ;  
  118.             match ( s1 , s ) ;  
  119.             li[t+1] = 1 ;  
  120.             for ( j = 0 ; j < len ; j ++ ) {  
  121.                 if ( ok[j] == len - j ) g[++t] = 1 ;  
  122.                 else g[++t] = 0 ;  
  123.                 vec[t] = s1[j] ;  
  124.                 if ( j ) li[t] = 0 ;  
  125.             }  
  126.         }  
  127.     //    for ( i = 1 ; i < t ; i ++ ) printf ( "%d " , cnt[i] ) ; puts ( "" ) ;  
  128.         int last = t + 1 ;  
  129.     //  for ( i = 1 ; i <= t ; i ++ ) printf ( "%d " , li[i] ) ; puts ( "" ) ;  
  130.         for ( i = t ; i >= 1 ; i -- ) {  
  131.             nxt[i] = last ;  
  132.             if ( li[i] ) last = i ;  
  133.         }  
  134. //      for ( i = 1 ; i <= t ; i ++ ) printf ( "%d " , nxt[i] ) ; puts ( "" ) ;  
  135.         ans = 0 ;  
  136.         cal ( t ) ;  
  137.         printf ( "%I64d\n" , ans ) ;  
  138.     }  
  139. }  

你可能感兴趣的:(字符串)