题目大意:
求一个字符串满足
1、w是s1的子串
2、w是s2的子串
3、s3不是w的子串
4、w的长度应尽可能大
刚开始想用SAM做,想了半天发现不是很可做啊。
还是好好想后缀数组的做法吧。这道题最恶心的限制就是限制3:s3不是w的子串。
那么如果没有这个限制该怎么做呢?我们用间隔符连接s1,s2,用后缀数组求出height数组。答案一定出自相邻两个后缀的height,且两个后缀分别出自s1,s2。为什么?因为如果是一段区间的height取min的话只可能变小不可能变大。
限制考虑限制3.我们对于s3建立适配函数,然后利用kmp求出s1,s2中s3每次出现的开头位置,用st表维护一段区间中是否有s3出现的开头位置。
我们在用height更新答案的时候,要考虑限制3.二分求出[sa[i],sa[i]+height[i]-len(s3)]中s3出现的最靠前的位置pos。然后对pos-sa[i]+1+len(s3)-2取min,再去更新答案。
#include
#include
#include
#include
#include
#define N 100003
using namespace std;
int t[N],a[N],xx[N],yy[N],*x,*y,height[N],rank[N],sa[N];
int n,m,len,len1,pd[N],pd1[N],pd2[N],v[N],p,st[20][N],L[N];
char s1[N],s2[N],s3[N];
void get_fail()
{
t[0]=-1; int j;
for (int i=0;iwhile (j!=-1&&s3[j]!=s3[i]) j=t[j];
t[i+1]=++j;
}
}
void kmp(char s[N],int a[N],int l)
{
int i=0; int j=0;
while (j<=l) {
if (s3[i]==s[j]||i==-1) i++,j++;
else i=t[i];
if (i==len1) {
a[j-len1]=1;
i=t[i];
}
}
}
int cmp(int i,int j,int k)
{
return y[i]==y[j]&&(i+k>len?-1:y[i+k])==(j+k>len?-1:y[j+k]);
}
void get_sa()
{
x=xx; y=yy; int m1=30;
for (int i=1;i<=len;i++) v[x[i]=a[i]]++;
for (int i=1;i<=m1;i++) v[i]+=v[i-1];
for (int i=len;i>=1;i--) sa[v[x[i]]--]=i;
for (int k=1;k<=len;k<<=1) {
p=0;
for (int i=len-k+1;i<=len;i++) y[++p]=i;
for (int i=1;i<=len;i++)
if (sa[i]>k) y[++p]=sa[i]-k;
for (int i=1;i<=m1;i++) v[i]=0;
for (int i=1;i<=len;i++) v[x[y[i]]]++;
for (int i=1;i<=m1;i++) v[i]+=v[i-1];
for (int i=len;i>=1;i--) sa[v[x[y[i]]]--]=y[i];
swap(x,y); p=2; x[sa[1]]=1;
for (int i=2;i<=len;i++)
x[sa[i]]=cmp(sa[i],sa[i-1],k)?p-1:p++;
if (p>len) break;
m1=p+1;
}
for (int i=1;i<=len;i++) rank[sa[i]]=i;
p=0;
for (int i=1;i<=len;i++) {
if (rank[i]==1) continue;
int j=sa[rank[i]-1];
while (i+p<=len&&j+p<=len&&a[i+p]==a[j+p]) p++;
height[rank[i]]=p;
p=max(0,p-1);
}
}
int calc(int x,int y)
{
int k=L[y-x];
return max(st[k][x],st[k][y-(1<1]);
}
int divide(int l,int r)
{
int t=l; int ans=r+1;
while (l<=r) {
int mid=(l+r)/2;
if (calc(t,mid)) ans=min(ans,mid),r=mid-1;
else l=mid+1;
}
return ans;
}
int main()
{
freopen("a.in","r",stdin);
//freopen("my.out","w",stdout);
scanf("%s",s1); n=strlen(s1);
scanf("%s",s2); m=strlen(s2);
scanf("%s",s3); len1=strlen(s3);
get_fail();
kmp(s1,pd1,n); kmp(s2,pd2,m);
for (int i=1;i<=n;i++) a[i]=s1[i-1]-'a'+1,pd[i]=pd1[i-1];
a[n+1]=0; len=n+1;
for (int i=1;i<=m;i++) a[++len]=s2[i-1]-'a'+1,pd[len]=pd2[i-1];
get_sa();
for (int i=1;i<=len;i++) st[0][i]=pd[i];
for (int i=1;i<=17;i++)
for (int j=1;j<=len;j++)
if (j+(1<1<=len) st[i][j]=max(st[i-1][j],st[i-1][j+(1<<(i-1))]);
int j=0;
for (int i=1;i<=len;i++) {
if (1<<(j+1)<=i) j++;
L[i]=j;
}
int ans=0;
for (int i=2;i<=len;i++)
if (sa[i]<=n&&sa[i-1]>n+1||sa[i]>n+1&&sa[i-1]<=n) {
int t=height[i];
int pos=divide(sa[i],sa[i]+height[i]-len1);
t=min(t,pos-sa[i]+1+len1-2);
ans=max(ans,t);
}
printf("%d\n",ans);
}