大致题意:
给出一个字符串,求出所有不重叠出现次数大于等于两次的子串的数目。
大致思路:
后缀数组的好题,思想很妙也很难想到。大致过程就是先枚举子串的长度tmp,对于每一个height值大于等于tmp的区间内找到sa的最大值和最小值,看他们之间的距离是否小于tmp,是的话则ans++;
#include<iostream> #include<cstdio> #include<cstring> using namespace std; const int inf=1<<30; const int nMax=500000; int num[nMax]; int sa[nMax], rank[nMax], height[nMax]; int wa[nMax], wb[nMax], wv[nMax], wd[nMax]; int cmp(int *r, int a, int b, int l){ return r[a] == r[b] && r[a+l] == r[b+l]; } void da(int *r, int n, int m){ // 倍增算法 r为待匹配数组 n为总长度 m为字符范围 int i, j, p, *x = wa, *y = wb, *t; for(i = 0; i < m; i ++) wd[i] = 0; for(i = 0; i < n; i ++) wd[x[i]=r[i]] ++; for(i = 1; i < m; i ++) wd[i] += wd[i-1]; for(i = n-1; i >= 0; i --) sa[-- wd[x[i]]] = i; for(j = 1, p = 1; p < n; j *= 2, m = p){ for(p = 0, i = n-j; i < n; i ++) y[p ++] = i; for(i = 0; i < n; i ++) if(sa[i] >= j) y[p ++] = sa[i] - j; for(i = 0; i < n; i ++) wv[i] = x[y[i]]; for(i = 0; i < m; i ++) wd[i] = 0; for(i = 0; i < n; i ++) wd[wv[i]] ++; for(i = 1; i < m; i ++) wd[i] += wd[i-1]; for(i = n-1; i >= 0; i --) sa[-- wd[wv[i]]] = y[i]; for(t = x, x = y, y = t, p = 1, x[sa[0]] = 0, i = 1; i < n; i ++){ x[sa[i]] = cmp(y, sa[i-1], sa[i], j) ? p - 1: p ++; } } } void calHeight(int *r, int n){ // 求height数组。 int i, j, k = 0; for(i = 1; i <= n; i ++) rank[sa[i]] = i; for(i = 0; i < n; height[rank[i ++]] = k){ for(k ? k -- : 0, j = sa[rank[i]-1]; r[i+k] == r[j+k]; k ++); } } int loc[nMax]; char str[nMax],res[nMax]; int abs(int a){ if(a>0)return a; return -a; } int getnum(int tmp,int len){ //长度为tmp的重复子串共有多少个 int i,maxl=-1,minl=inf,res=0; for(i=2;i<=len;i++){ if(height[i]>=tmp){ maxl=max(sa[i-1],max(sa[i],maxl)); minl=min(sa[i-1],min(sa[i],minl)); } else{ if(maxl!=-1&&minl!=inf&&maxl-minl>=tmp){ res++; } maxl=-1; minl=inf; } } if(maxl!=-1&&minl!=inf&&maxl-minl>=tmp)res++; return res; } int main(){ int n,i,sp,ans; while(scanf("%s",str)&&str[0]!='#'){ n=strlen(str); sp=30; for(i=0;str[i];i++){ num[i]=str[i]-'a'+1; } num[n]=0; da(num,n+1,sp+4); calHeight(num,n); ans=0; for(i=1;i<=n/2;i++){ // cout<<"getnum "<<i<<"="<<getnum(i,n)<<endl; ans+=getnum(i,n); } printf("%d\n",ans); } return 0; }