地址:http://poj.org/problem?id=3415
题意:给你两个字符串,还有一个数字K,要求这两个字符串长度大于等于K的相同子串对数,具体看题目
分析:这题求相同子串,自然就会让人想到后缀数组之类的解法,不过后缀数组只能求出最长的公共子串,还有一些公共前缀的信息,没办法完成统计子串对数,我想了半天都没想出办法,然后看了一眼讨论,有人说是单调性。。。我就继续往这方面思考了,仔细研究了下高度数组height,你会发现,当height[ i ]> heigh[ i-1 ]时,它跟前面的公共前缀不可能大于height[ i ]了,也就是对于j<i,如果height[ j ]> height[ i ],那么height[ j ]-height[ i ]这部分加上公共部分height[ i ]组成的相同子串不会与后面的相同了,也就是说,这部分可以独自统计,而剩下的部分height[ i ]可以与i一起算就行。我们维护一个桟,保证桟中的高度不断递增,当出现一个高度比桟顶元素低的,说明栈顶元素比他高的部分组成的子串可以统计了,统计完后就退桟,知道桟为空,或者栈顶元素低于当前元素,这里需要注意的地方就是统计时统计的部分要认真考虑,我在这里错了,导致wa了n次T_T
看不明白的话,就看下面的图吧:
代码:
/** head files*/ #include <cstdlib> #include <cctype> #include <cstring> #include <cstdio> #include <cmath> #include <algorithm> #include <vector> #include <string> #include <iostream> #include <sstream> #include <map> #include <set> #include <queue> #include <stack> #include <fstream> #include <numeric> #include <iomanip> #include <bitset> #include <list> #include <stdexcept> #include <functional> #include <utility> #include <ctime> using namespace std; /** some operate*/ #define PB push_back #define MP make_pair #define REP(i,n) for(i=0;i<(n);++i) #define UPTO(i,l,h) for(i=(l);i<=(h);++i) #define DOWN(i,h,l) for(i=(h);i>=(l);--i) #define MSET(arr,val) memset(arr,val,sizeof(arr)) #define MAX3(a,b,c) max(a,max(b,c)) #define MAX4(a,b,c,d) max(max(a,b),max(c,d)) #define MIN3(a,b,c) min(a,min(b,c)) #define MIN4(a,b,c,d) min(min(a,b),min(c,d)) /** some const*/ #define N 222222 #define M 222222 #define PI acos(-1.0) #define oo 1111111111 /** some alias*/ typedef long long ll; /** Global variables*/ /** some template names, just push ctrl+j to get it in*/ //manacher 求最长回文子串 //pqueue 优先队列 //combk n元素序列的第m小的组合和 //pmatrix n个点的最大子矩阵 //suffixarray 后缀数组 template <typename T, int LEN> struct suffixarray { int str[LEN*3],sa[LEN*3]; int rank[LEN],height[LEN]; int id[LEN]; int len; bool equal(int *str, int a, int b) { return str[a]==str[b]&&str[a+1]==str[b+1]&&str[a+2]==str[b+2]; } bool cmp3(int *str, int *nstr, int a, int b) { if(str[a]!=str[b])return str[a]<str[b]; if(str[a+1]!=str[b+1])return str[a+1]<str[b+1]; return nstr[a+b%3]<nstr[b+b%3]; } void radixsort(int *str, int *sa, int *res, int n, int m) { int i; REP(i,m)id[i]=0; REP(i,n)++id[str[sa[i]]]; REP(i,m)id[i+1]+=id[i]; DOWN(i,n-1,0)res[--id[str[sa[i]]]]=sa[i]; } void dc3(int *str, int *sa, int n, int m) { #define F(x) ((x)/3+((x)%3==1?0:one)) #define G(x) ((x)<one?(x)*3+1:((x)-one)*3+2) int *nstr=str+n, *nsa=sa+n, *tmpa=rank, *tmpb=height; int i,j,k,len=0,num=0,zero=0,one=(n+1)/3; REP(i,n)if(i%3)tmpa[len++]=i; str[n]=str[n+1]=0; radixsort(str+2, tmpa, tmpb, len, m); radixsort(str+1, tmpb, tmpa, len, m); radixsort(str+0, tmpa, tmpb, len, m); nstr[F(tmpb[0])]=num++; UPTO(i,1,len-1) nstr[F(tmpb[i])]=equal(str,tmpb[i-1],tmpb[i])?num-1:num++; if(num<len)dc3(nstr,nsa,len,num); else REP(i,len)nsa[nstr[i]]=i; if(n%3==1)tmpa[zero++]=n-1; REP(i,len)if(nsa[i]<one)tmpa[zero++]=nsa[i]*3; radixsort(str, tmpa, tmpb, zero, m); REP(i,len)tmpa[nsa[i]=G(nsa[i])]=i; i=j=0; REP(k,n) if(j>=len||(i<zero&&cmp3(str,tmpa,tmpb[i],nsa[j])))sa[k]=tmpb[i++]; else sa[k]=nsa[j++]; } void initSA(T *s, int n,int m) { int i,j,k=0; str[len=n]=0; REP(i,n)str[i]=s[i]; dc3(str,sa,n+1,m); REP(i,n)sa[i]=sa[i+1]; REP(i,n)rank[sa[i]]=i; REP(i,n) { if(k)--k; if(rank[i])for(j=sa[rank[i]-1];str[i+k]==str[j+k];++k); else k=0; height[rank[i]]=k; } } }; suffixarray<char,N> msa; stack<int> stk; char s[N],tmp[N]; int sum[N][2],suma,sumb; int main() { int i,l,r,k,n,m,now,left,low; ll ans; while(~scanf("%d",&k)) { if(k==0)break; scanf("%s%s",s,tmp); s[n=strlen(s)]=1; m=strlen(tmp); REP(i,m)s[n+i+1]=tmp[i]; m=n+m+1; msa.initSA(s,m,256); REP(i,m)MSET(sum[i],0); REP(i,m) { if(msa.sa[i]<n)++sum[i][0]; if(msa.sa[i]>n)++sum[i][1]; sum[i+1][0]+=sum[i][0]; sum[i+1][1]+=sum[i][1]; } while(!stk.empty())stk.pop(); ans=0; msa.height[m]=0; left=0; REP(i,m) { if(!stk.empty()) { r=stk.top(); while(!stk.empty()&&msa.height[stk.top()]>=msa.height[i+1]) { now=msa.height[stk.top()]; stk.pop(); if(stk.empty())l=left,low=k-1; else l=stk.top(),low=msa.height[stk.top()]; low=max(low,msa.height[i+1]); suma=sum[r][0]-sum[l-1][0]; sumb=sum[r][1]-sum[l-1][1]; ans+=(ll)suma*sumb*(now-low); } } if(msa.height[i+1]>=k)stk.push(i+1); else left=i+1; } cout<<ans<<endl; } return 0; }