给你一个字符串 S S S,然后有 q q q个询问。对于第 i i i个询问,给出两个数字 k i k_i ki和 l i l_i li序列 a 1 , a 2 , . . . , a k i a_1,a_2,...,a_{k_i} a1,a2,...,aki和 b 1 , b 2 , . . . , b l i b_1,b_2,...,b_{l_i} b1,b2,...,bli,让你求
∑ i = 1 i = k ∑ j = 1 j = l LCP ( s [ a i … n ] , s [ b j … n ] ) \sum\limits_{i = 1}^{i = k} \sum\limits_{j = 1}^{j = l}{\text{LCP}(s[a_i \dots n], s[b_j \dots n])} i=1∑i=kj=1∑j=lLCP(s[ai…n],s[bj…n])
这题有很多种做法,我说一下我一眼想到的两种方法。
第一种方法,可以参见 POJ 3415 方法是类似的。对于每个询问,我们把给出的两个序列合并成一个,按照 s a sa sa进行排序。然后从前往后扫一遍,显然 h e i g h t height height的最小值是具有单调性的,可以维护一个单调栈,每次移动一个区间,用这一个区间的 h e i g h t height height最小值去更新单调栈。正反各自跑一次即可求出两个序列分别作为左右端点时的答案。这样对于每个读入的序列,都只是 O ( 1 ) O(1) O(1)的处理,复杂度就是求 s a sa sa的 O ( N l o g N + ∑ i = 1 i = q k i + ∑ i = 1 i = q l i ) O(NlogN+\sum\limits_{i = 1}^{i = q}{k_i}+\sum\limits_{i = 1}^{i = q}{l_i}) O(NlogN+i=1∑i=qki+i=1∑i=qli)。
第二种方法,还是要按照 s a sa sa排序,然后根据 h e i g h t height height的大小,分治去求答案。首先,找到总区间的最小值的位置,然后在两个序列中二分这个最小值的位置,看在最小值点左右的点的数目,可以统计一下对答案的贡献。然后以最小值点为中间拆开继续分治。这种方法,对于每个区间,只需要 O ( 1 ) O(1) O(1)的找区间最小值的位置,所以均摊下来总的复杂度是 O ( N l o g N + Q l o g Q ) O(NlogN+QlogQ) O(NlogN+QlogQ),其中 Q = ∑ i = 1 i = q k i + ∑ i = 1 i = q l i Q=\sum\limits_{i = 1}^{i = q}{k_i}+\sum\limits_{i = 1}^{i = q}{l_i} Q=i=1∑i=qki+i=1∑i=qli。
PS:一开始认为方法一好写,然后复杂度好像低一点,但是最后写了好久才调出来。然后写博客的时候发现方法二也好写,于是又写了一遍,只用了一个小时左右,而且事实证明还更快。
在不加快读的情况下,方法一: 1138 m s 1138ms 1138ms,方法二: 280 m s 280ms 280ms。
方法一:
#include
#define INF 0x3f3f3f3f
#define eps 1e-5
#define pi 3.141592653589793
#define LL long long
#define pb push_back
#define fi first
#define se second
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<" : "<
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;
const int mod = 1e9 + 7;
const int N = 200010;
char ss[N]; int dp[N][20];
struct node{int lcp,id,t;} stk[N];
std::vector<pair<int,int> > v;
int sa[N],Rank[N],h[N],n;
int xx[N],yy[N],c[N]; char *s;
bool cmp(int *s,int x,int y,int k)
{return (s[x]==s[y])&&(s[x+k]==s[y+k]);}
void ins(char *str) {s=str;n=strlen(s)+1;}
void DA()
{
memset(c,0,sizeof(c));
int *x=xx,*y=yy,m=130,*t,i;
for(i=0;i<n;i++) x[i]=s[i];
for(i=0;i<n;i++) c[x[i]]++;
for(i=1;i<m;i++) c[i]+=c[i-1];
for(i=n-1;i>=0;i--) sa[--c[x[i]]]=i;
for(int k=1,tot=0;tot<n;k<<=1,m=tot)
{
memset(c,0,sizeof(c));
for(i=0;i<n;i++) c[x[i]]++;
for(i=1;i<m;i++) c[i]+=c[i-1];
for(i=n-k,tot=0;i<n;i++) y[tot++]=i;
for(i=0;i<n;i++) if (sa[i]>=k) y[tot++]=sa[i]-k;
for(i=n-1;i>=0;i--) sa[--c[x[y[i]]]]=y[i];
for(i=tot=1,t=x,x=y,y=t,x[sa[0]]=0;i<n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],k)?tot-1:tot++;
}
}
void cal_height()
{
int i,j,k=0;
for(i=1;i<n;i++) Rank[sa[i]]=i;
for(i=0;i<n-1;h[Rank[i++]]=k)
for(k?k--:0,j=sa[Rank[i]-1];s[i+k]==s[j+k];k++);
int m=floor(log(n+0.0)/log(2.0));
for(int i=1;i<=n;i++) dp[i][0]=h[i];
for(int i=1;i<=m;i++)
for(int j=n;j;j--)
{
dp[j][i]=dp[j][i-1];
if(j+(1<<(i-1))<=n) dp[j][i]=min(dp[j][i],dp[j+(1<<(i-1))][i-1]);
}
}
inline int lcp(int l,int r)
{
int m=31-__builtin_clz(r-l+1);
return min(dp[l][m],dp[r-(1<<m)+1][m]);
}
inline bool cmpp(pair<int,int> a,pair<int,int> b)
{
return a.fi==b.fi?a.se>b.se:a.fi<b.fi;
}
unordered_map<int,bool> mp;
int main(int argc, char const *argv[])
{
int n,q;
scc(n,q);
scanf("%s",ss);
ins(ss); DA();
cal_height();
while(q--)
{
LL ans=0;
v.clear();
mp.clear();
int k,l; scc(k,l);
for(int i=1;i<=k;i++)
{
int x; sc(x); mp[x]=1; x--;
v.pb({Rank[x]+1,1});
}
for(int i=1;i<=l;i++)
{
int x; sc(x);
if (mp.count(x)) ans+=n-x+1;
x--; v.pb({Rank[x]+1,0});
}
sort(v.begin(),v.end());
for(int t=0;t<2;t++)
{
LL res=0; int top=0;
for(auto i:v)
{
int x=i.fi,y=i.se;
node now={h[x],x,y^t};
if (!(y^t))
{
if (!top||x==stk[top].id) continue;
if (x-1>=stk[top].id+1)
{
int LCP=lcp(stk[top].id+1,x-1),cnt=0;
while(top&&LCP<=stk[top].lcp)
{
res-=(LL)stk[top].t*(stk[top].lcp-LCP);
cnt+=stk[top--].t;
}
if (cnt&&LCP) stk[++top]={LCP,x-1,cnt};
}
ans+=res;
} else
{
int LCP=now.lcp,cnt=0;
if (top)
if (stk[top].id+1<=x)
LCP=min(LCP,lcp(stk[top].id+1,x));
res+=now.lcp;
while(top&&LCP<=stk[top].lcp)
{
res-=(LL)stk[top].t*(stk[top].lcp-LCP);
cnt+=stk[top--].t;
}
if (cnt&&LCP) stk[++top]={LCP,x,cnt};
stk[++top]=now;
}
}
if (!t) sort(v.begin(),v.end(),cmpp);
}
printf("%lld\n",ans);
}
return 0;
}
方法二:
#include
#define INF 0x3f3f3f3f
#define eps 1e-5
#define pi 3.141592653589793
#define LL long long
#define pb push_back
#define fi first
#define se second
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<" : "<
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;
const int mod = 1e9 + 7;
const int N = 200010;
char ss[N]; int m,q;
int a[N],b[N]; LL ans;
pair<int,int> dp[N][18];
int sa[N],Rank[N],h[N],n;
int xx[N],yy[N],c[N]; char *s;
bool cmp(int *s,int x,int y,int k)
{return (s[x]==s[y])&&(s[x+k]==s[y+k]);}
void ins(char *str) {s=str;n=strlen(s)+1;}
inline void DA()
{
memset(c,0,sizeof(c));
int *x=xx,*y=yy,m=130,*t,i;
for(i=0;i<n;i++) x[i]=s[i];
for(i=0;i<n;i++) c[x[i]]++;
for(i=1;i<m;i++) c[i]+=c[i-1];
for(i=n-1;i>=0;i--) sa[--c[x[i]]]=i;
for(int k=1,tot=0;tot<n;k<<=1,m=tot)
{
memset(c,0,sizeof(c));
for(i=0;i<n;i++) c[x[i]]++;
for(i=1;i<m;i++) c[i]+=c[i-1];
for(i=n-k,tot=0;i<n;i++) y[tot++]=i;
for(i=0;i<n;i++) if (sa[i]>=k) y[tot++]=sa[i]-k;
for(i=n-1;i>=0;i--) sa[--c[x[y[i]]]]=y[i];
for(i=tot=1,t=x,x=y,y=t,x[sa[0]]=0;i<n;i++)
x[sa[i]]=cmp(y,sa[i-1],sa[i],k)?tot-1:tot++;
}
}
inline void cal_height()
{
int i,j,k=0;
for(i=1;i<n;i++) Rank[sa[i]]=i;
for(i=0;i<n-1;h[Rank[i++]]=k)
for(k?k--:0,j=sa[Rank[i]-1];s[i+k]==s[j+k];k++);
int m=floor(log(n+0.0)/log(2.0));
for(int i=1;i<=n;i++) dp[i][0]={h[i],i};
for(int i=1;i<=m;i++)
for(int j=n;j;j--)
{
dp[j][i]=dp[j][i-1];
if(j+(1<<(i-1))<=n) dp[j][i]=min(dp[j][i],dp[j+(1<<(i-1))][i-1]);
}
}
inline pair<int,int> lcp(int l,int r)
{
int m=31-__builtin_clz(r-l+1);
return min(dp[l][m],dp[r-(1<<m)+1][m]);
}
void solve(int l,int r,int L,int R)
{
int ll=min(a[l],b[L]);
int rr=max(a[r],b[R]);
if (ll==rr) {ans+=m-sa[a[l]];return;}
pair<int,int> tmp=lcp(ll+1,rr);
int pos=tmp.se-1,lcp=tmp.fi;
int pos1=ub(a+l,a+r+1,pos)-a;
int pos2=ub(b+L,b+R+1,pos)-b;
ans+=((LL)(pos1-l)*(R-pos2+1)+(LL)(pos2-L)*(r-pos1+1))*lcp;
if (pos1>l&&pos2>L) solve(l,pos1-1,L,pos2-1);
if (pos1<=r&&pos2<=R) solve(pos1,r,pos2,R);
}
namespace IO{
#define BUF_SIZE 100000
#define OUT_SIZE 100000
#define ll long long
//fread->read
bool IOerror=0;
inline char nc(){
static char buf[BUF_SIZE],*p1=buf+BUF_SIZE,*pend=buf+BUF_SIZE;
if (p1==pend){
p1=buf; pend=buf+fread(buf,1,BUF_SIZE,stdin);
if (pend==p1){IOerror=1;return -1;}
}
return *p1++;
}
inline bool blank(char ch){return ch==' '||ch=='\n'||ch=='\r'||ch=='\t';}
inline void read(int &x){
bool sign=0; char ch=nc(); x=0;
for (;blank(ch);ch=nc());
if (IOerror)return;
if (ch=='-')sign=1,ch=nc();
for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
if (sign)x=-x;
}
inline void read(ll &x){
bool sign=0; char ch=nc(); x=0;
for (;blank(ch);ch=nc());
if (IOerror)return;
if (ch=='-')sign=1,ch=nc();
for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
if (sign)x=-x;
}
inline void read(char *s){
char ch=nc();
for (;blank(ch);ch=nc());
if (IOerror)return;
for (;!blank(ch)&&!IOerror;ch=nc())*s++=ch;
*s=0;
}
}
struct Ostream_fwrite{
char *buf,*p1,*pend;
Ostream_fwrite(){buf=new char[BUF_SIZE];p1=buf;pend=buf+BUF_SIZE;}
void out(char ch){
if (p1==pend){
fwrite(buf,1,BUF_SIZE,stdout);p1=buf;
}
*p1++=ch;
}
void print(LL x){
static char s[15],*s1;s1=s;
if (!x)*s1++='0';if (x<0)out('-'),x=-x;
while(x)*s1++=x%10+'0',x/=10;
while(s1--!=s)out(*s1);
}
void print(char *s){while (*s)out(*s++);}
void flush(){if (p1!=buf){fwrite(buf,1,p1-buf,stdout);p1=buf;}}
~Ostream_fwrite(){flush();}
}Ostream;
inline void print(LL x){Ostream.print(x);}
inline void print(char *s){Ostream.print(s);}
inline void println(LL x){print(x);print("\n");}
using namespace IO;
int main(int argc, char const *argv[])
{
read(m);
read(q);
read(ss);
ins(ss); DA();
cal_height();
while(q--)
{
ans=0;
int k,l; read(k); read(l);
for(int i=1;i<=k;i++)
{
int x; read(x); x--;
a[i]=Rank[x];
}
for(int i=1;i<=l;i++)
{
int x; read(x); x--;
b[i]=Rank[x];
}
sort(a+1,a+1+k);
sort(b+1,b+1+l);
solve(1,k,1,l);
println(ans);
}
return 0;
}