poj 3415Common Substrings (后缀自动机)
题意:给出两个串,问这两个串的所有的子串中(重复出现的,只要是位置不同就算两个子串),长度大于等于k的公共子串有多少个。
解题思路:第一个真正意义上独立完成的后缀自动机。。我们这样做,先给第一个串建好sam,然后用第二个串去sam上匹配,匹配过程中,记录temp,表示s2匹配到当前位置时,能匹配的最大长度,假设此时匹配到sam上的位置是p。这时,我们可以得到的长度大于等于k的,在s2中以当前位置结束的公共子串的个数,为temp - max ( k , p->fa->len + 1 ) + 1对吧,但是这样的串在s1中会出现在哪些位置呢?那就是p的right集合了。所以我们先处理出每个节点的right集合的大小,这里我用cnt数组记录。那么s2[i]这个位置,与s1的p状态能匹配的大于等于k的公共子串的个数就是 ( temp - max ( k , p->fa->len + 1 ) + 1 ) * cnt[p] 了。注意到,我们这里用k与p->fa->len + 1去取了一个max,这是为什么呢?因为我当前只统计了p状态。所以我们还有统计p沿着fa边走的状态,起初我是每次都沿着fa边往上走,然后每次做一下统计,TLE了。。改进,如果k <= p->fa->len , 那么我们必然还可以从父亲那里匹配出若干个符合要求的子串,而且不管是从哪个儿子过来的,fa处能增加的个数都是一样的。因此我们像线段树延迟标记那样,给父亲节点打上一个标记,当s2枚举完了之后,在根据拓扑序,将延迟标记dp往上推就好了。
(有问题请留言详细讨论)
#include<stdio.h> #include<string.h> #include<algorithm> #define ll __int64 using namespace std ; const int maxn = 111111 ; int fa[maxn<<1] , c[52][maxn<<1] , val[maxn<<1] ; int tot , last ; int max ( int a , int b ) { return a > b ? a : b ; } inline int get ( char c ) { if ( c >= 'a' && c <= 'z' ) return c - 'a' ; else return c - 'A' + 26 ; } inline int new_node ( int step ) { int i ; val[++tot] = step ; for ( i = 0 ; i < 52 ; i ++ ) c[i][tot] = 0 ; fa[tot] = 0 ; return tot ; } void add ( int k ) { int p = last , i ; int np = new_node ( val[p] + 1 ) ; while ( p && !c[k][p] ) c[k][p] = np , p = fa[p] ; if ( !p ) fa[np] = 1 ; else { int q = c[k][p] ; if ( val[q] == val[p] + 1 ) fa[np] = q ; else { int nq = new_node ( val[p] + 1 ) ; for ( i = 0 ; i < 52 ; i ++ ) c[i][nq] = c[i][q] ; fa[nq] = fa[q] ; fa[np] = fa[q] = nq ; while ( p && c[k][p] == q ) c[k][p] = nq , p = fa[p] ; } } last = np ; } int ws[maxn<<1] , pos[maxn<<1] , cnt[maxn<<1] , col[maxn<<1] ; void Sort () { int i ; for ( i = 1 ; i <= tot ; i ++ ) ws[i] = 0 ; for ( i = 1 ; i <= tot ; i ++ ) ws[val[i]] ++ ; for ( i = 1 ; i <= tot ; i ++ ) ws[i] += ws[i-1] ; for ( i = 1 ; i <= tot ; i ++ ) pos[ws[val[i]]--] = i ; } void cal ( char *s ) { int p = 1 , i , len = strlen ( s ) ; for ( i = 1 ; i <= tot ; i ++ ) cnt[i] = col[i] = 0 ; for ( i = 0 ; i < len ; i ++ ) { int k = get ( s[i] ) ; cnt[c[k][p]] ++ ; p = c[k][p] ; } for ( i = tot ; i >= 1 ; i -- ) { int q = pos[i] ; cnt[fa[q]] += cnt[q] ; } } ll ans ; void solve ( char *s , int k ) { int i , len = strlen ( s ) ; int temp = 0 , p = 1 ; ans = 0 ; for ( i = 0 ; i < len ; i ++ ) { int d = get ( s[i] ) ; if ( c[d][p] ) { temp ++ ; p = c[d][p] ; } else { while ( p && !c[d][p] ) p = fa[p] ; if ( !p ) temp = 0 , p = 1 ; else temp = val[p] + 1 , p = c[d][p] ; } int q = p , fuck = temp ; if ( fuck >= k ) { ans += (ll) ( fuck - max ( k , val[fa[p]] + 1 ) + 1 ) * cnt[p] ; if ( k <= val[fa[p]] ) col[fa[p]] ++ ; } // printf ( "i = %d , ans = %I64d\n" , i , ans ) ; } for ( i = tot ; i >= 1 ; i -- ) { p = pos[i] ; ans += (ll) col[p] * ( val[p] - max ( k , val[fa[p]] + 1 ) + 1 ) * cnt[p] ; // if ( i == 2 ) printf ( "val = %d , fa = %d , cnt = %d\n" , val[p] , val[fa[p]] + 1 , cnt[p] ) ; if ( k <= val[fa[p]] ) col[fa[p]] += col[p] ; // printf ( "i = %d , ans = %I64d , col = %d\n" , i , ans , col[i] ) ; } } void init () { tot = 0 ; last = new_node ( 0 ) ; } char s[maxn] ; int k ; int main () { while ( scanf ( "%d" , &k ) != EOF ) { if ( k == 0 ) break ; init () ; scanf ( "%s" , s ) ; int i , len = strlen ( s ) ; for ( i = 0 ; i < len ; i ++ ) add ( get ( s[i] ) ) ; Sort () ; cal ( s ) ; scanf ( "%s" , s ) ; solve ( s , k ) ; printf ( "%I64d\n" , ans ) ; } }