题目链接:
http://uoj.ac/problem/103
题解:
一道题让我新了解到了两个算法:处理回文子串问题的manacher算法与快速求RMQ的ST算法,至于后缀数组之前学习过不过还是抄模板了,附学习资料:
manacher :
http://www.open-open.com/lib/view/open1419150233417.html
ST算法:
http://blog.csdn.net/insistgogo/article/details/9929103
后缀数组模板:
http://blog.csdn.net/deritt/article/details/50830677
本题中首先用manacher求回文子串,每求出一个回文子串就利用后缀数组的height数组与ST算法,二分寻找共有多少个相同的子串,复杂度大概是nlogn的时间复杂度,可以接受。
写了一早上,先是因为没用ST表而超时,然后是各种细节处理错误,醉了
代码:
#include<iostream>
#include<algorithm>
#include<stdio.h>
#include<string.h>
#define maxn (300005)
using namespace std;
char s[maxn];
int SA[maxn],wsf[maxn],wv[maxn],wa[maxn],wb[maxn],R[maxn],lg2[maxn],ST[maxn][20],slen;
int cmp(int *r,int a,int b,int l)
{return r[a]==r[b]&&r[a+l]==r[b+l];}
void da(int *r,int *sa,int n,int m)
{
int *x=wa,*y=wb,*t,i,j,p;
for (i=0;i<m;i++) wsf[i]=0;
for (i=0;i<n;i++)wsf[x[i]=r[i]]++;
for (i=1;i<m;i++)wsf[i]+=wsf[i-1];
for (i=n-1;i>=0;i--) sa[--wsf[x[i]]]=i;
for (j=1,p=1;p<n;j*=2,m=p)
{
for (p=0,i=n-j;i<n;i++) y[p++]=i;
for (i=0;i<n;i++) if (sa[i]-j>=0) y[p++]=sa[i]-j;
for (i=0;i<n;i++) wv[i]=x[y[i]];
for (i=0;i<m;i++) wsf[i]=0;
for (i=0;i<n;i++) wsf[wv[i]]++;
for (i=1;i<m;i++) wsf[i]+=wsf[i-1];
for (i=n-1;i>=0;i--) sa[--wsf[wv[i]]]=y[i];
for (t=x,x=y,y=t,p=1,x[sa[0]]=0,i=1;i<n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
return ;
}
int rank[maxn],ht[maxn];
void getht(int *r,int *sa,int n)
{
int i,j,k=0;
for (i=1;i<=n;i++) rank[sa[i]]=i;
for (i=0;i<n;ht[rank[i++]]=k)
for (k?k--:0,j=sa[rank[i]-1];r[i+k]==r[j+k];k++);
//求st表
for(i=2;i<=n;++i)lg2[i]=lg2[i>>1]+1;
for(i=1;i<=n;++i)ST[i][0]=ht[i];
for(j=1;(1<<j)<=n;j++)
for(i=1;i+(1<<j)-1<=n;i++)
ST[i][j]=min(ST[i][j-1],ST[i+(1<<(j-1))][j-1]);
return ;
}
int Query(int l,int r)//st表的O(1)查询
{
int tmp=lg2[r-l+1];
return min(ST[l][tmp],ST[r-(1<<tmp)+1][tmp]);
}
long long query(int st,int l)//查找有多少个子串
{
//l为这个回文子串的长度
//st位这个以这个回文子串开头的后缀的排名
int ql=st,qr=st;
int L=1,R=st;
while(R>=L)
{
int mid=(L+R)/2;
int pan=Query(mid,st);
if (pan>=l){ql=mid-1;R=mid-1;}
else L=mid+1;
}
//cout<<slen<<endl;
L=st,R=slen;
while(R>=L)
{
int mid=(L+R)/2;
int pan=Query(st+1,mid);
if (pan>=l) {qr=mid;L=mid+1;}
else R=mid-1;
}
int num=qr-ql+1;
//if (l==4)
//cout<<qr<<' '<<ql<<endl;
return (long long)num*(long long)l;
}
char tmp[maxn*2];
int len[maxn*2];
int init(char *str)//manacher算法中添加'#'
{
int tlen=strlen(str);
tmp[0]='@';
for (int i=1;i<=2*tlen;i+=2)
{
tmp[i]='#';
tmp[i+1]=str[i/2];
}
tmp[2*tlen+1]='#';
tmp[2*tlen+2]='$';
tmp[2*tlen+3]=0;
return 2*tlen+1;
}
long long manacher(char *str,int tlen)//manacher算法在此题中的应用
{
int mx=0,po=0;
long long ans=0;
for (int i=1;i<=tlen;i++)
{
if (mx>i)
len[i]=min(mx-i,len[2*po-i]);
else
len[i]=1;
while(str[i-len[i]]==str[i+len[i]])
{
len[i]++;
if (str[i-len[i]]!='#')//如果当前更新的不是添加的'#'
ans=max(ans,query(rank[(i-len[i]+2)/2-1],len[i]-1));
}
if (len[i]+i>mx)
{
mx=len[i]+i;
po=i;
}
}
return ans;
}
int main()
{
scanf("%s",s);
int tlen=strlen(s);
slen=tlen;
for (int i=0;i<tlen;i++)
R[i]=s[i]-'a'+1;
R[tlen]=0;
da(R,SA,tlen+1,28);
getht(R,SA,tlen);
tlen=init(s);
long long zans=manacher(tmp,tlen);
printf("%lld\n",zans);
}