poj 3376 Finding Palindromes(扩展kmp+trie)

poj 3376 Finding Palindromes(扩展kmp+trie)

题意:给出n个字符串,问这n个字符串两两链接(一共有n^2中连接方法),组成的所有的字符串中,有多少个回文串。

好题!!!

解题思路:对于两个串连接是否是回文串,我们应该怎样去判断了?假如我们把其中一个翻转,若此时,短的那个串是长的那个的前缀,而长的那个串后面剩余的后缀恰好是个回文串,那这两个串连起来就是个回文串了。比如abc和abacba连接,我们把后一个串翻转,得到abcaba,abc为其前缀,而aba是个回文串,那么连起来就是个回文串了(这个规律找到就好办了)。那我们就把所有的串插入到trie中,然后再用所有的反串去匹配就行了。匹配的过程中,走到任意一个节点,而这个节点有可能是若干个串的结尾,那么此时我们就要判反串匹配位置下面剩余的部分是否回文。如果是的,ans就加上以这个节点为结尾的原串的个数(这个插入的时候就可以统计进去了)。如果走完了,还没走到叶子节点,那么就要看走到的节点下的子树(其实是以前面走过的路径为前缀的字符串剩下的一些后缀)有多少是回文的了(这个先预处理所有的串的后缀有哪些是回文的,然后在插入的时候统计到节点上)。剩下来一个问题就是如何在线性的时间内(或许o(nlogn)也可以吧,但我们有线性的算法,岂不更好?),这里我只是说,用扩展kmp能很合适。具体如何实现,留个小思考给大家(很简单的啦)。。

#include<stdio.h>
#include<algorithm>
#include<string.h>
#define ll __int64
#include<vector>
using namespace std ;

const int maxn = 2222222 ;
char vec[maxn] ;
int g[maxn] , nxt[maxn] ;
bool li[maxn] ;
int ok[maxn] , p[maxn] ;
void get_p (const char *T){
     int len=strlen(T),a=0;
     int i , k ;
     p[0]=len;
     while(a<len-1 && T[a]==T[a+1]) a++;
     p[1]=a;
     a=1;
     for( k=2;k<len;k++){
         int fuck=a+p[a]-1,L=p[k-a];
         if( (k-1)+L >= fuck){
             int j = (fuck-k+1)>0 ? (fuck-k+1) : 0;
             while(k+j<len && T[k+j]==T[j]) j++;
             p[k]=j;
             a=k;
         }
         else p[k]=L;
     }
}

void match ( char *s , char *s1 ) {
    int len = strlen ( s ) , len1 = strlen ( s1 ) ;
    int i = 0 , k , j = 0 , a ;
    while ( i < len && j < len1 && s[i] == s1[j] ) i ++ , j ++ ;
    ok[0] = j ;
    a = 0 ;
    for ( k = 1 ; k < len ; k ++ ) {
        int fuck = a + ok[a] - 1 , l = p[k-a] ;
        if ( k + l - 1 >= fuck ) {
            int j = ( fuck - k + 1 ) > 0 ? ( fuck - k + 1 ) : 0 ;
            while ( k + j < len && j < len1 && s[k+j] == s1[j] ) j ++ ;
            ok[k] = j ;
            a = k ;
        }
        else ok[k] = l ;
    }
}

int tot = 0 , c[26][maxn] , cnt[maxn] , val[maxn] ;

int new_node () {
    int i ;
    for ( i = 0 ; i < 26 ; i ++ ) c[i][tot] = 0 ;
    cnt[tot] = val[tot] = 0 ;
    return tot ++ ;
}

void insert ( char *s ) {
    int len = strlen ( s ) , i , now = 0 ;
    for ( i = 0 ; i < len ; i ++ ) {
        int k = s[i] - 'a' ;
        if ( !c[k][now] ) c[k][now] = new_node () ;
        now = c[k][now] ;
        if ( i + 1 < len && ok[i+1] == len - i - 1 ) {
            cnt[now] ++ ;
        }
    }
	cnt[now] ++ ;
    val[now] ++ ;
}

ll ans = 0 ;

void cal ( int len ) {
    int j , i , now = 0 ;
	li[len+1] = 1 ;
//	printf ( "len = %d\n" , len ) ;
//	for ( i = 1 ; i <= len ; i ++ ) printf ( "%d " , nxt[i] ) ; puts ( "" ) ;
    for ( j = 1 ; j <= len ; j ++ ) {
	//	printf ( "j = %d , ans = %I64d\n" , j , ans ) ;
		if ( li[j] ) now = 0 ;
        int k = vec[j] - 'a' ;
        if ( !c[k][now] ) {
            now = 0 ;
	//		printf ( "nxt[%d] = %d\n" , j , nxt[j] ) ;
			j = nxt[j] - 1 ;
			continue ;
        }
        now = c[k][now] ;
        if ( !li[j+1] && g[j+1] ) ans += (ll) val[now] ;
//		printf ( "j = %d , now = %d\n" , j , now ) ;
		if ( li[j+1] ) {
	//		if ( j == 10 ) printf ( "cnt[%d] = %d\n" , now , cnt[now] ) ;
			ans += (ll) cnt[now] ;
			now = 0 ;
		}
    }
}

char s1[maxn] , s[maxn] ;

int main () {
    int n , i , j , k ;
    while ( scanf ( "%d" , &n ) != EOF ) {
        tot = 0 ;
        new_node () ;
		int t = 0 ;
        for ( i = 1 ; i <= n ; i ++ ) {
            scanf ( "%d%s" , &j , s ) ;
            strcpy ( s1 , s ) ;
            int len = strlen ( s ) ;
            reverse ( s1 , s1 + len ) ;
            get_p ( s1 ) ;
            match ( s , s1 ) ;
            insert ( s ) ;
            get_p ( s ) ;
            match ( s1 , s ) ;
			li[t+1] = 1 ;
            for ( j = 0 ; j < len ; j ++ ) {
                if ( ok[j] == len - j ) g[++t] = 1 ;
                else g[++t] = 0 ;
                vec[t] = s1[j] ;
				if ( j ) li[t] = 0 ;
            }
        }
    //    for ( i = 1 ; i < t ; i ++ ) printf ( "%d " , cnt[i] ) ; puts ( "" ) ;
		int last = t + 1 ;
	//	for ( i = 1 ; i <= t ; i ++ ) printf ( "%d " , li[i] ) ; puts ( "" ) ;
		for ( i = t ; i >= 1 ; i -- ) {
			nxt[i] = last ;
			if ( li[i] ) last = i ;
		}
//		for ( i = 1 ; i <= t ; i ++ ) printf ( "%d " , nxt[i] ) ; puts ( "" ) ;
        ans = 0 ;
        cal ( t ) ;
        printf ( "%I64d\n" , ans ) ;
    }
}


你可能感兴趣的:(扩展kmp)