poj 3415Common Substrings (后缀自动机)

poj 3415Common Substrings (后缀自动机)

题意:给出两个串,问这两个串的所有的子串中(重复出现的,只要是位置不同就算两个子串),长度大于等于k的公共子串有多少个。

解题思路:第一个真正意义上独立完成的后缀自动机。。我们这样做,先给第一个串建好sam,然后用第二个串去sam上匹配,匹配过程中,记录temp,表示s2匹配到当前位置时,能匹配的最大长度,假设此时匹配到sam上的位置是p。这时,我们可以得到的长度大于等于k的,在s2中以当前位置结束的公共子串的个数,为temp - max ( k , p->fa->len + 1 ) + 1对吧,但是这样的串在s1中会出现在哪些位置呢?那就是p的right集合了。所以我们先处理出每个节点的right集合的大小,这里我用cnt数组记录。那么s2[i]这个位置,与s1的p状态能匹配的大于等于k的公共子串的个数就是 ( temp - max ( k , p->fa->len + 1 ) + 1 ) * cnt[p] 了。注意到,我们这里用k与p->fa->len + 1去取了一个max,这是为什么呢?因为我当前只统计了p状态。所以我们还有统计p沿着fa边走的状态,起初我是每次都沿着fa边往上走,然后每次做一下统计,TLE了。。改进,如果k <= p->fa->len , 那么我们必然还可以从父亲那里匹配出若干个符合要求的子串,而且不管是从哪个儿子过来的,fa处能增加的个数都是一样的。因此我们像线段树延迟标记那样,给父亲节点打上一个标记,当s2枚举完了之后,在根据拓扑序,将延迟标记dp往上推就好了。

(有问题请留言详细讨论)

#include<stdio.h>
#include<string.h>
#include<algorithm>
#define ll __int64
using namespace std ;

const int maxn = 111111 ;
int fa[maxn<<1] , c[52][maxn<<1] , val[maxn<<1] ;
int tot , last ;
int max ( int a , int b ) { return a > b ? a : b ; }
inline int get ( char c ) {
	if ( c >= 'a' && c <= 'z' ) return c - 'a' ;
	else return c - 'A' + 26 ;
}

inline int new_node ( int step ) {
	int i ;
	val[++tot] = step ;
	for ( i = 0 ; i < 52 ; i ++ ) c[i][tot] = 0 ;
	fa[tot] = 0 ;
	return tot ;
}

void add ( int k ) {
	int p = last , i ;
	int np = new_node ( val[p] + 1 ) ;
	while ( p && !c[k][p] ) c[k][p] = np , p = fa[p] ;
	if ( !p ) fa[np] = 1 ;
	else {
		int q = c[k][p] ;
		if ( val[q] == val[p] + 1 ) fa[np] = q ;
		else {
			int nq = new_node ( val[p] + 1 ) ;
			for ( i = 0 ; i < 52 ; i ++ ) c[i][nq] = c[i][q] ;
			fa[nq] = fa[q] ;
			fa[np] = fa[q] = nq ;
			while ( p && c[k][p] == q ) c[k][p] = nq , p = fa[p] ;
		}
	}
	last = np ;
}

int ws[maxn<<1] , pos[maxn<<1] , cnt[maxn<<1] , col[maxn<<1] ;
void Sort () {
	int i ;
	for ( i = 1 ; i <= tot ; i ++ ) ws[i] = 0 ;
	for ( i = 1 ; i <= tot ; i ++ ) ws[val[i]] ++ ;
	for ( i = 1 ; i <= tot ; i ++ ) ws[i] += ws[i-1] ;
	for ( i = 1 ; i <= tot ; i ++ ) pos[ws[val[i]]--] = i ;
}

void cal ( char *s ) {
	int p = 1 , i , len = strlen ( s ) ;
	for ( i = 1 ; i <= tot ; i ++ ) cnt[i] = col[i] = 0 ;
	for ( i = 0 ; i < len ; i ++ ) {
		int k = get ( s[i] ) ;
		cnt[c[k][p]] ++ ;
		p = c[k][p] ;
	}
	for ( i = tot ; i >= 1 ; i -- ) {
		int q = pos[i] ;
		cnt[fa[q]] += cnt[q] ;
	}
}

ll ans ;
void solve ( char *s , int k ) {
	int i , len = strlen ( s ) ;
	int temp = 0 , p = 1 ;
	ans = 0 ;
	for ( i = 0 ; i < len ; i ++ ) {
		int d = get ( s[i] ) ;
		if ( c[d][p] ) {
			temp ++ ;
			p = c[d][p] ;
		}
		else {
			while ( p && !c[d][p] ) p = fa[p] ;
			if ( !p ) temp = 0 , p = 1 ;
			else temp = val[p] + 1 , p = c[d][p] ;
		}
		int q = p , fuck = temp ;
		if ( fuck >= k ) {
			ans += (ll) ( fuck - max ( k , val[fa[p]] + 1 ) + 1 ) * cnt[p] ;
			if ( k <= val[fa[p]] ) col[fa[p]] ++ ;
		}
	//	printf ( "i = %d , ans = %I64d\n" , i , ans ) ;
	}
	for ( i = tot ; i >= 1 ; i -- ) {
		p = pos[i] ;
		ans += (ll) col[p] * ( val[p] - max ( k , val[fa[p]] + 1 ) + 1 ) * cnt[p] ;
//		if ( i == 2 ) printf ( "val = %d , fa = %d , cnt = %d\n" , val[p] , val[fa[p]] + 1 , cnt[p] ) ;
		if ( k <= val[fa[p]] ) col[fa[p]] += col[p] ;
	//	printf ( "i = %d , ans = %I64d , col = %d\n" , i , ans , col[i] ) ;
	}
}

void init () {
	tot = 0 ;
	last = new_node ( 0 ) ;
}

char s[maxn] ;
int k ;

int main () {
	while ( scanf ( "%d" , &k ) != EOF ) {
		if ( k == 0 ) break ;
		init () ;
		scanf ( "%s" , s ) ;
		int i , len = strlen ( s ) ;
		for ( i = 0 ; i < len ; i ++ ) add ( get ( s[i] ) ) ;
		Sort () ;
		cal ( s ) ;
		scanf ( "%s" , s ) ;
		solve ( s , k ) ;
		printf ( "%I64d\n" , ans ) ;
	}
}


你可能感兴趣的:(后缀自动机)