CodeForces - 1073G Yet Another LCP Problem(后缀数组 + 单调栈 / 分治)

CodeForces - 1073G Yet Another LCP Problem(后缀数组 + 单调栈 / 分治)_第1张图片
CodeForces - 1073G Yet Another LCP Problem(后缀数组 + 单调栈 / 分治)_第2张图片

大致题意

给你一个字符串 S S S,然后有 q q q个询问。对于第 i i i个询问,给出两个数字 k i k_i ki l i l_i li序列 a 1 , a 2 , . . . , a k i a_1,a_2,...,a_{k_i} a1,a2,...,aki b 1 , b 2 , . . . , b l i b_1,b_2,...,b_{l_i} b1,b2,...,bli,让你求
∑ i = 1 i = k ∑ j = 1 j = l LCP ( s [ a i … n ] , s [ b j … n ] ) \sum\limits_{i = 1}^{i = k} \sum\limits_{j = 1}^{j = l}{\text{LCP}(s[a_i \dots n], s[b_j \dots n])} i=1i=kj=1j=lLCP(s[ain],s[bjn])

做法

这题有很多种做法,我说一下我一眼想到的两种方法。

第一种方法,可以参见 POJ 3415 方法是类似的。对于每个询问,我们把给出的两个序列合并成一个,按照 s a sa sa进行排序。然后从前往后扫一遍,显然 h e i g h t height height的最小值是具有单调性的,可以维护一个单调栈,每次移动一个区间,用这一个区间的 h e i g h t height height最小值去更新单调栈。正反各自跑一次即可求出两个序列分别作为左右端点时的答案。这样对于每个读入的序列,都只是 O ( 1 ) O(1) O(1)的处理,复杂度就是求 s a sa sa O ( N l o g N + ∑ i = 1 i = q k i + ∑ i = 1 i = q l i ) O(NlogN+\sum\limits_{i = 1}^{i = q}{k_i}+\sum\limits_{i = 1}^{i = q}{l_i}) O(NlogN+i=1i=qki+i=1i=qli)

第二种方法,还是要按照 s a sa sa排序,然后根据 h e i g h t height height的大小,分治去求答案。首先,找到总区间的最小值的位置,然后在两个序列中二分这个最小值的位置,看在最小值点左右的点的数目,可以统计一下对答案的贡献。然后以最小值点为中间拆开继续分治。这种方法,对于每个区间,只需要 O ( 1 ) O(1) O(1)的找区间最小值的位置,所以均摊下来总的复杂度是 O ( N l o g N + Q l o g Q ) O(NlogN+QlogQ) O(NlogN+QlogQ),其中 Q = ∑ i = 1 i = q k i + ∑ i = 1 i = q l i Q=\sum\limits_{i = 1}^{i = q}{k_i}+\sum\limits_{i = 1}^{i = q}{l_i} Q=i=1i=qki+i=1i=qli

代码

PS:一开始认为方法一好写,然后复杂度好像低一点,但是最后写了好久才调出来。然后写博客的时候发现方法二也好写,于是又写了一遍,只用了一个小时左右,而且事实证明还更快。

在不加快读的情况下,方法一: 1138 m s 1138ms 1138ms,方法二: 280 m s 280ms 280ms

方法一:

#include
#define INF 0x3f3f3f3f
#define eps 1e-5
#define pi 3.141592653589793
#define LL long long
#define pb push_back
#define fi first
#define se second
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<"      :   "<
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;

const int mod = 1e9 + 7;
const int N = 200010;

char ss[N]; int dp[N][20];
struct node{int lcp,id,t;} stk[N];
std::vector<pair<int,int> > v;

int sa[N],Rank[N],h[N],n;
int xx[N],yy[N],c[N]; char *s;
bool cmp(int *s,int x,int y,int k)
{return (s[x]==s[y])&&(s[x+k]==s[y+k]);}
void ins(char *str) {s=str;n=strlen(s)+1;}

void DA()
{
    memset(c,0,sizeof(c));
    int *x=xx,*y=yy,m=130,*t,i;
    for(i=0;i<n;i++) x[i]=s[i];
    for(i=0;i<n;i++) c[x[i]]++;
    for(i=1;i<m;i++) c[i]+=c[i-1];
    for(i=n-1;i>=0;i--) sa[--c[x[i]]]=i;
    for(int k=1,tot=0;tot<n;k<<=1,m=tot)
    {
        memset(c,0,sizeof(c));
        for(i=0;i<n;i++) c[x[i]]++;
        for(i=1;i<m;i++) c[i]+=c[i-1];
        for(i=n-k,tot=0;i<n;i++) y[tot++]=i;
        for(i=0;i<n;i++) if (sa[i]>=k) y[tot++]=sa[i]-k;
        for(i=n-1;i>=0;i--) sa[--c[x[y[i]]]]=y[i];
        for(i=tot=1,t=x,x=y,y=t,x[sa[0]]=0;i<n;i++)
            x[sa[i]]=cmp(y,sa[i-1],sa[i],k)?tot-1:tot++;
    }
}

void cal_height()
{
    int i,j,k=0;
    for(i=1;i<n;i++) Rank[sa[i]]=i;
    for(i=0;i<n-1;h[Rank[i++]]=k)
        for(k?k--:0,j=sa[Rank[i]-1];s[i+k]==s[j+k];k++);
    int m=floor(log(n+0.0)/log(2.0));
    for(int i=1;i<=n;i++) dp[i][0]=h[i];
    for(int i=1;i<=m;i++)
        for(int j=n;j;j--)
        {
            dp[j][i]=dp[j][i-1];
            if(j+(1<<(i-1))<=n) dp[j][i]=min(dp[j][i],dp[j+(1<<(i-1))][i-1]);
        }
}

inline int lcp(int l,int r)
{
    int m=31-__builtin_clz(r-l+1);
    return min(dp[l][m],dp[r-(1<<m)+1][m]);
}


inline bool cmpp(pair<int,int> a,pair<int,int> b)
{
    return a.fi==b.fi?a.se>b.se:a.fi<b.fi;
}

unordered_map<int,bool> mp;

int main(int argc, char const *argv[])
{
    int n,q;
    scc(n,q);
    scanf("%s",ss);
    ins(ss); DA();
    cal_height();
    while(q--)
    {
        LL ans=0;
        v.clear();
        mp.clear();
        int k,l; scc(k,l);
        for(int i=1;i<=k;i++)
        {
            int x; sc(x); mp[x]=1; x--;
            v.pb({Rank[x]+1,1});
        }
        for(int i=1;i<=l;i++)
        {
            int x; sc(x);
            if (mp.count(x)) ans+=n-x+1;
            x--; v.pb({Rank[x]+1,0});
        }
        sort(v.begin(),v.end());
        for(int t=0;t<2;t++)
        {
            LL res=0; int top=0;
            for(auto i:v)
            {
                int x=i.fi,y=i.se;
                node now={h[x],x,y^t};
                if (!(y^t))
                {
                    if (!top||x==stk[top].id) continue;
                    if (x-1>=stk[top].id+1)
                    {
                        int LCP=lcp(stk[top].id+1,x-1),cnt=0;
                        while(top&&LCP<=stk[top].lcp)
                        {
                            res-=(LL)stk[top].t*(stk[top].lcp-LCP);
                            cnt+=stk[top--].t;
                        }
                        if (cnt&&LCP) stk[++top]={LCP,x-1,cnt};
                    }
                    ans+=res;
                } else
                {
                    int LCP=now.lcp,cnt=0;
                    if (top)
                        if (stk[top].id+1<=x)
                            LCP=min(LCP,lcp(stk[top].id+1,x));
                    res+=now.lcp;
                    while(top&&LCP<=stk[top].lcp)
                    {
                        res-=(LL)stk[top].t*(stk[top].lcp-LCP);
                        cnt+=stk[top--].t;
                    }
                    if (cnt&&LCP) stk[++top]={LCP,x,cnt};
                    stk[++top]=now;
                }
            }
            if (!t) sort(v.begin(),v.end(),cmpp);
        }
        printf("%lld\n",ans);
    }
    return 0;
}

方法二:

#include
#define INF 0x3f3f3f3f
#define eps 1e-5
#define pi 3.141592653589793
#define LL long long
#define pb push_back
#define fi first
#define se second
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<"      :   "<
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;

const int mod = 1e9 + 7;
const int N = 200010;

char ss[N]; int m,q;
int a[N],b[N]; LL ans;
pair<int,int> dp[N][18];

int sa[N],Rank[N],h[N],n;
int xx[N],yy[N],c[N]; char *s;
bool cmp(int *s,int x,int y,int k)
{return (s[x]==s[y])&&(s[x+k]==s[y+k]);}
void ins(char *str) {s=str;n=strlen(s)+1;}

inline void DA()
{
    memset(c,0,sizeof(c));
    int *x=xx,*y=yy,m=130,*t,i;
    for(i=0;i<n;i++) x[i]=s[i];
    for(i=0;i<n;i++) c[x[i]]++;
    for(i=1;i<m;i++) c[i]+=c[i-1];
    for(i=n-1;i>=0;i--) sa[--c[x[i]]]=i;
    for(int k=1,tot=0;tot<n;k<<=1,m=tot)
    {
        memset(c,0,sizeof(c));
        for(i=0;i<n;i++) c[x[i]]++;
        for(i=1;i<m;i++) c[i]+=c[i-1];
        for(i=n-k,tot=0;i<n;i++) y[tot++]=i;
        for(i=0;i<n;i++) if (sa[i]>=k) y[tot++]=sa[i]-k;
        for(i=n-1;i>=0;i--) sa[--c[x[y[i]]]]=y[i];
        for(i=tot=1,t=x,x=y,y=t,x[sa[0]]=0;i<n;i++)
            x[sa[i]]=cmp(y,sa[i-1],sa[i],k)?tot-1:tot++;
    }
}

inline void cal_height()
{
    int i,j,k=0;
    for(i=1;i<n;i++) Rank[sa[i]]=i;
    for(i=0;i<n-1;h[Rank[i++]]=k)
        for(k?k--:0,j=sa[Rank[i]-1];s[i+k]==s[j+k];k++);
        int m=floor(log(n+0.0)/log(2.0));
    for(int i=1;i<=n;i++) dp[i][0]={h[i],i};
    for(int i=1;i<=m;i++)
        for(int j=n;j;j--)
        {
            dp[j][i]=dp[j][i-1];
            if(j+(1<<(i-1))<=n) dp[j][i]=min(dp[j][i],dp[j+(1<<(i-1))][i-1]);
        }
}

inline pair<int,int> lcp(int l,int r)
{
    int m=31-__builtin_clz(r-l+1);
    return min(dp[l][m],dp[r-(1<<m)+1][m]);
}

void solve(int l,int r,int L,int R)
{
    int ll=min(a[l],b[L]);
    int rr=max(a[r],b[R]);
    if (ll==rr) {ans+=m-sa[a[l]];return;}
    pair<int,int> tmp=lcp(ll+1,rr);
    int pos=tmp.se-1,lcp=tmp.fi;
    int pos1=ub(a+l,a+r+1,pos)-a;
    int pos2=ub(b+L,b+R+1,pos)-b;
    ans+=((LL)(pos1-l)*(R-pos2+1)+(LL)(pos2-L)*(r-pos1+1))*lcp;
    if (pos1>l&&pos2>L) solve(l,pos1-1,L,pos2-1);
    if (pos1<=r&&pos2<=R) solve(pos1,r,pos2,R);
}

namespace IO{
    #define BUF_SIZE 100000
    #define OUT_SIZE 100000
    #define ll long long
    //fread->read

    bool IOerror=0;
    inline char nc(){
        static char buf[BUF_SIZE],*p1=buf+BUF_SIZE,*pend=buf+BUF_SIZE;
        if (p1==pend){
            p1=buf; pend=buf+fread(buf,1,BUF_SIZE,stdin);
            if (pend==p1){IOerror=1;return -1;}
        }
        return *p1++;
    }
    inline bool blank(char ch){return ch==' '||ch=='\n'||ch=='\r'||ch=='\t';}
    inline void read(int &x){
        bool sign=0; char ch=nc(); x=0;
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        if (ch=='-')sign=1,ch=nc();
        for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
        if (sign)x=-x;
    }
    inline void read(ll &x){
        bool sign=0; char ch=nc(); x=0;
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        if (ch=='-')sign=1,ch=nc();
        for (;ch>='0'&&ch<='9';ch=nc())x=x*10+ch-'0';
        if (sign)x=-x;
    }
    inline void read(char *s){
        char ch=nc();
        for (;blank(ch);ch=nc());
        if (IOerror)return;
        for (;!blank(ch)&&!IOerror;ch=nc())*s++=ch;
        *s=0;
    }
}

struct Ostream_fwrite{
        char *buf,*p1,*pend;
        Ostream_fwrite(){buf=new char[BUF_SIZE];p1=buf;pend=buf+BUF_SIZE;}
        void out(char ch){
            if (p1==pend){
                fwrite(buf,1,BUF_SIZE,stdout);p1=buf;
            }
            *p1++=ch;
        }
        void print(LL x){
            static char s[15],*s1;s1=s;
            if (!x)*s1++='0';if (x<0)out('-'),x=-x;
            while(x)*s1++=x%10+'0',x/=10;
            while(s1--!=s)out(*s1);
        }
        void print(char *s){while (*s)out(*s++);}
        void flush(){if (p1!=buf){fwrite(buf,1,p1-buf,stdout);p1=buf;}}
        ~Ostream_fwrite(){flush();}
    }Ostream;
inline void print(LL x){Ostream.print(x);}
inline void print(char *s){Ostream.print(s);}
inline void println(LL x){print(x);print("\n");}

using namespace IO;

int main(int argc, char const *argv[])
{
    read(m);
    read(q);
    read(ss);
    ins(ss); DA();
    cal_height();
    while(q--)
    {
        ans=0;
        int k,l; read(k); read(l);
        for(int i=1;i<=k;i++)
        {
            int x; read(x); x--;
            a[i]=Rank[x];
        }
        for(int i=1;i<=l;i++)
        {
            int x; read(x); x--;
            b[i]=Rank[x];
        }
        sort(a+1,a+1+k);
        sort(b+1,b+1+l);
        solve(1,k,1,l);
        println(ans);
    }
    return 0;
}

你可能感兴趣的:(CodeForces,后缀数组)