【codeforces】2014 Asia Xi'an Regional Contest G The Problem to Slow Down You 【Palindromic Tree】

传送门:【codeforces】2014 Asia Xi'an Regional Contest G The Problem to Slow Down You 【Palindromic Tree】


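题目分析:对两个字符串分别建立回文树,并调用 count() 把 cnt 沿 fail 链向上累加,使每个节点的 cnt 等于它所代表的回文串在原串中的出现次数。之后在两棵树上同步 dfs:先从两棵树的偶根(节点 0)同时出发,再从奇根(节点 1)同时出发;如果两棵树的当前节点都存在沿同一字符 c 的儿子,说明这两个儿子代表同一个回文串(它在两个串中都出现),于是 ans 加上这两个儿子 cnt 的乘积(即该回文串在两个串中出现位置的配对数),然后沿这个方向继续 dfs。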


代码如下:


#include <cstdio>
#include <cstring>
#include <algorithm>
#include <utility>
#include <vector>
using namespace std ;

// Accumulator type for the answer (long long: at least 64 bits).
typedef long long LL ;

// Loop / memory shorthands used by the rest of the file:
//   rep( i , a , b ) : i = a, a+1, ..., b-1   (half-open, ascending)
//   For( i , a , b ) : i = a, a+1, ..., b     (closed, ascending)
//   rev( i , a , b ) : i = a, a-1, ..., b     (closed, descending)
//   clr( a , x )     : memset the whole array a with byte value x
#define rep( i , a , b ) for ( int i = ( a ) ; i <  ( b ) ; ++ i )
#define For( i , a , b ) for ( int i = ( a ) ; i <= ( b ) ; ++ i )
#define rev( i , a , b ) for ( int i = ( a ) ; i >= ( b ) ; -- i )
#define clr( a , x ) memset ( a , x , sizeof a )

// Capacity of the input buffers and of every per-tree array (both the
// character array S and the node arrays). Each added character creates at
// most one tree node, plus the two roots, so a string of length L needs
// L + 2 node slots and L + 1 slots in S. Presumably sized for strings of up
// to 2*10^5 characters; TODO confirm against the problem's limits.
const int MAXN = 200005 ;
// Alphabet size; characters are mapped with c - 'a', so input is assumed to
// be lowercase 'a'..'z'.
const int N = 26 ;

/*
 * Palindromic tree (eertree) built incrementally over a string of lowercase
 * letters.
 *
 * Node 0 is the even root (len 0) and node 1 is the odd root (len -1). Every
 * other node stands for one distinct palindromic substring pal(x) of the
 * characters added so far, and next[x][c] leads from pal(x) to c + pal(x) + c.
 * Each add() creates at most one node, so L characters need at most L + 2
 * nodes; that must stay below MAXN.
 *
 * Usage: init(); add(ch) for every character; count(); then read cnt[].
 */
struct Palindromic_Tree {
	int next[MAXN][N] ; // next[x][c]: node of c + pal(x) + c, 0 if absent
	int fail[MAXN] ;    // fail[x]: node of the longest proper palindromic suffix of pal(x)
	int cnt[MAXN] ;     // occurrence counter; complete only after count()
	int len[MAXN] ;     // len[x]: length of pal(x); -1 for the odd root
	int S[MAXN] , n ;   // S[1..n]: added characters as 0..N-1; S[0] = -1 sentinel
	int last ;          // node of the longest palindromic suffix of S[1..n]
	int p ;             // number of nodes allocated so far

	// Allocate a fresh node with palindrome length l and return its index.
	int newnode ( int l ) {
		for ( int i = 0 ; i < N ; ++ i ) next[p][i] = 0 ;
		cnt[p] = 0 ;
		len[p] = l ;
		return p ++ ;
	}

	// Reset to an empty tree that contains only the two roots.
	void init () {
		p = 0 ;
		newnode (  0 ) ;
		newnode ( -1 ) ;
		last = 0 ;
		n = 0 ;
		S[n] = -1 ;   // never equals a real character (0..N-1)
		fail[0] = 1 ;
		// add() reads fail[1] when it creates a single-letter palindrome
		// (cur == 1). Set it explicitly instead of relying on the instance
		// living in zero-initialised storage.
		fail[1] = 0 ;
	}

	// Walk fail links from x until the character just before pal(x) equals
	// the newest character S[n], i.e. S[n] + pal(x) + S[n] is a suffix of
	// S[1..n]. Stops at the odd root at the latest, because there
	// n - len - 1 == n.
	int get_fail ( int x ) {
		while ( S[n - len[x] - 1] != S[n] ) x = fail[x] ;
		return x ;
	}

	// Append character c (expected in 'a'..'z') to the string.
	void add ( int c ) {
		c -= 'a' ;
		S[++ n] = c ;
		int cur = get_fail ( last ) ;
		if ( !next[cur][c] ) {
			int now = newnode ( len[cur] + 2 ) ;
			// Compute the fail link before linking `now` into the tree. When
			// cur is the odd root, the search can end at the odd root again,
			// and next[1][c] must not yet point at `now` itself.
			fail[now] = next[get_fail ( fail[cur] )][c] ;
			next[cur][c] = now ;
		}
		last = next[cur][c] ;
		cnt[last] ++ ;
	}

	// Propagate counts along fail links so that, for every non-root node x,
	// cnt[x] becomes the number of occurrences of pal(x). Visiting nodes in
	// decreasing index order is valid because a node's fail target is a
	// shorter palindrome created earlier (fail[x] < x). The entries of the
	// two roots end up holding aggregate values with no palindrome meaning.
	void count () {
		for ( int i = p - 1 ; i >= 0 ; -- i ) cnt[fail[i]] += cnt[i] ;
	}
} ;

// Per-test state shared by solve() and dfs(). These live at namespace scope
// because each tree holds several MAXN-sized arrays (next alone is
// MAXN * N ints), which is far too large for a stack frame.
Palindromic_Tree T1 ;   // palindromic tree of s1
Palindromic_Tree T2 ;   // palindromic tree of s2
char s1[MAXN] ;         // first input string
char s2[MAXN] ;         // second input string
int n1 ;                // strlen ( s1 )
int n2 ;                // strlen ( s2 )
LL ans ;                // answer accumulated for the current test case

void dfs ( int u , int v ) {
	rep ( i , 0 , 26 ) {
		int x = T1.next[u][i] , y = T2.next[v][i] ;
		if ( x && y ) {
			ans += ( LL ) T1.cnt[x] * T2.cnt[y] ;
			dfs ( x , y ) ;
		}
	}
}

/*
 * Handle one test case. Read s1 and s2, build a palindromic tree for each,
 * and print the sum, over every palindrome present in both strings, of
 * (occurrences in s1) * (occurrences in s2).
 */
void solve () {
	ans = 0 ;
	T1.init () ;
	T2.init () ;
	scanf ( "%s%s" , s1 , s2 ) ;
	n1 = strlen ( s1 ) ;
	n2 = strlen ( s2 ) ;
	rep ( i , 0 , n1 ) T1.add ( s1[i] ) ;
	rep ( i , 0 , n2 ) T2.add ( s2[i] ) ;
	// Turn per-position counters into total occurrence counts.
	T1.count () ;
	T2.count () ;
	// Match even-length palindromes (from the even roots), then odd-length
	// ones (from the odd roots).
	dfs ( 0 , 0 ) ;
	dfs ( 1 , 1 ) ;
	// %lld is the standard conversion for long long. %I64d is an MSVC-runtime
	// extension and is undefined behaviour on glibc and other C libraries.
	printf ( "%lld\n" , ans ) ;
}

/*
 * Read the number of test cases and solve each one, prefixing its output
 * with "Case #k: ".
 */
int main () {
	int T = 0 ;
	// On empty or malformed input, T would otherwise be read uninitialised.
	if ( scanf ( "%d" , &T ) != 1 ) return 0 ;
	For ( cas , 1 , T ) {
		printf ( "Case #%d: " , cas ) ;
		solve () ;
	}
	return 0 ;
}


你可能感兴趣的:(codeforces)