poj Common Substrings(后缀数组&单调队列)

Common Substrings
Time Limit: 5000MS   Memory Limit: 65536K
Total Submissions: 7082   Accepted: 2355

Description

A substring of a string T is defined as:

T( ik)= TiTi +1... Ti+k -1, 1≤ ii+k-1≤| T|.

Given two strings AB and one integer K, we define S, a set of triples (ijk):

S = {( ijk) |  kKA( ik)= B( jk)}.

You are to give the value of |S| for specific AB and K.

Input

The input file contains several blocks of data. For each block, the first line contains one integer K, followed by two lines containing strings A and B, respectively. The input file is ended by K=0.

1 ≤ |A|, |B| ≤ 105
1 ≤ K ≤ min{|A|, |B|}
Characters of A and B are all Latin letters.

Output

For each case, output an integer |S|.

Sample Input

2
aababaa
abaabaa
1
xx
xx
0

Sample Output

22
5

Source

POJ Monthly--2007.10.06, wintokk

题意

给你两个长度不超过10^5的字符串A,B。和一个数字K。K<=MIN(|A|,|B|)。现在问你A,B的公共子串中长度不小于K的有多少对。位置可以重叠。

思路:

既然涉及到公共子串容易想到的是利用后缀数组的heght数组。对于长度不小于K的公共子串。很容易想把A,B串连接起来中间用一个特殊字符隔开按height数组分组。heght不小于k的连续height分一组。这样就可以保证同组内的公共前缀(子串)长度不小于K了。然后就是统计长度不小于K的公共子串的对数了。比较好想的做法是对组内每一个B的后缀。计算它与组内的每一个A的后缀能产生多少长度不小于K的公共子串。假设B额后缀b与A的后缀a的最长公共前缀的长度为len。那么他们能产生长度不小于K的公共子串的个数为len-K+1.为什么可以这样统计呢。因为不同的后缀在字符串的起点肯定不一样。所以统计子串只要统计后缀的前缀就行了。但是这样统计的话时间复杂度为O(n^2)。明显过不了。只能换一种思维统计。对于每一个B的后缀我们知道它与前面所有A的后缀能产生多少答案。然后对于A的每一个后缀它与前面B的后缀能产生多少答案。这样我们只要扫描两次就能得到最终答案了。现在关键就是怎样快速做到对于每一个B的后缀我们知道它与前面所有A的后缀能产生多少答案了。可以知道影响两个后缀的最长公共前缀的是中间最小的heght。所以我们想到用单挑栈来维护这个信息。这样我们就把heght一个一个往栈里加。这个栈的第i个元素维护目前有num[i]个A的后缀到当前位置的he[i]最小值是q[i]。这样我们每遇到一个A的后缀那么它对后面B的贡献为height[i]-K+1.一直到它到待计算B中间有height值小于它。所以我们记前面A对后面B的贡献为base。假设有新的height加进栈会使一num[a]个q[a],num[b]个q[b]。。。。的A的后缀到后面B的heght值变成height[i]。那么base=base-(q[a]-heght[i])*num[a]-(q[b]-heght[i])*num[b]。然后遇到B的后缀加上base就行了。统计A的前面B的后缀同理。这个问题就圆满解决了。

详细见代码:

#include<algorithm>
#include<iostream>
#include<string.h>
#include<stdio.h>
using namespace std;
const int INF=0x3f3f3f3f;
const int maxn=200020;
typedef long long ll;
int sa[maxn],T1[maxn],T2[maxn],he[maxn],rk[maxn],ct[maxn],q[maxn],num[maxn];
int n,m;
char txt[maxn];
ll ans;
void getsa(char *st)
{
    int i,k,p,*x=T1,*y=T2;
    for(i=0; i<m; i++) ct[i]=0;
    for(i=0; i<n; i++) ct[x[i]=st[i]]++;
    for(i=1; i<m; i++) ct[i]+=ct[i-1];
    for(i=n-1; i>=0; i--)
        sa[--ct[x[i]]]=i;
    for(k=1,p=1; p<n; k<<=1,m=p)
    {
        for(p=0,i=n-k; i<n; i++) y[p++]=i;
        for(i=0; i<n; i++) if(sa[i]>=k) y[p++]=sa[i]-k;
        for(i=0; i<m; i++) ct[i]=0;
        for(i=0; i<n; i++) ct[x[y[i]]]++;
        for(i=1; i<m; i++) ct[i]+=ct[i-1];
        for(i=n-1; i>=0; i--) sa[--ct[x[y[i]]]]=y[i];
        for(swap(x,y),p=1,x[sa[0]]=0,i=1; i<n; i++)
            x[sa[i]]=y[sa[i-1]]==y[sa[i]]&&y[sa[i-1]+k]==y[sa[i]+k]?p-1:p++;
    }
}
void gethe(char *st)
{
    int i,j,k=0;
    for(i=0;i<n;i++) rk[sa[i]]=i;
    for(i=0;i<n-1;i++)
    {
        if(k) k--;
        j=sa[rk[i]-1];
        while(st[i+k]==st[j+k]) k++;
        he[rk[i]]=k;
    }
}
int main()
{
    int mid,i,k,tp,cnt;
    ll base;

    while(scanf("%d",&k),k)
    {
        scanf("%s",txt);
        mid=n=strlen(txt);
        txt[n]='$';
        scanf("%s",txt+n+1);
        n=strlen(txt)+1;
        m=150,ans=tp=base=0;
        getsa(txt);
        gethe(txt);
        for(i=1;i<n;i++)
        {
            if(he[i]<k)
                tp=0,base=0;//tp为栈顶。base为前面对后面的贡献
            else
            {
                cnt=0;//统计有多少A的后缀是以he[i]为最小值的。
                if(sa[i-1]<mid)//注意这里必须为sa[i-1]。因为he[i]为i和i-1的最长公共前缀。所以只能确定sa[i-1]与后面的关系
                    cnt++,base+=he[i]-k+1;
                while(tp&&he[i]<q[tp-1])
                {
                    base-=(q[tp-1]-he[i])*num[tp-1];//最小值改变。对后面的b产生影响
                    cnt+=num[tp-1];
                    tp--;
                }
                q[tp]=he[i];
                num[tp++]=cnt;
                if(sa[i]>mid)
                    ans+=base;
            }
        }
        tp=0,base=0;
        for(i=1;i<n;i++)
        {
            if(he[i]<k)
                tp=0,base=0;
            else
            {
                cnt=0;
                if(sa[i-1]>mid)
                    cnt++,base+=he[i]-k+1;
                while(tp&&he[i]<q[tp-1])
                {
                    base-=(q[tp-1]-he[i])*num[tp-1];
                    cnt+=num[tp-1];
                    tp--;
                }
                q[tp]=he[i];
                num[tp++]=cnt;
                if(sa[i]<mid)
                    ans+=base;
            }
        }
        printf("%I64d\n",ans);
    }
    return 0;
}


你可能感兴趣的:(c,算法,ACM)