[bzoj2555]substring 解题报告

考虑用splay维护sa,每次比较的时候二分+hash。注意要在两端加-∞和+∞的点。
时间复杂度 $O((n+q)\log^2 n+m)$(n 是数据总长度,m 是询问总长度)
但是普通的hash的话需要用long long+除法运算,一个点需要跑6s。所以改成自然溢出,瞬间只需要2s了。。。(要是被卡了怎么办。。)
听说如果用重量平衡树维护的话可以做到 $O((n+q)\log n+m)$,然而我并不会。。
代码:

#include<cstdio>
#include<iostream>
using namespace std;
#include<algorithm>
#include<cstring>
#include<cmath>
#include<ctime>
typedef long long LL;
// Array capacities. N bounds the text positions (node ids 1..N); M bounds the
// query buffers. Q is not referenced by any code visible in this file.
const int N=6e5,M=3e6+5,Q=1e4+5;

// Profiling counters; only commented-out instrumentation refers to them.
int tot,csum,ctmp;

// Multiplier of the polynomial hash. All hash values wrap modulo 2^32.
const LL base=317;
int pres[N+5];   // prefix hashes of the text: pres[i] covers text[1..i], pres[0]==0
int qpres[M];    // prefix hashes of the current query, same layout
int power[N+5];  // power[k] == base^k (wrapped), filled in by main

// Hash of the half-open slice (l,r] of the sequence whose prefix hashes are s[].
// The arithmetic is carried out in unsigned so that the intended wrap-around is
// well defined (signed overflow would be undefined behaviour).
// Requires 0 <= l <= r and r-l within the filled part of power[].
int gethash(int s[],int l,int r){//(l,r]
    const unsigned hi=static_cast<unsigned>(s[r]);
    const unsigned lo=static_cast<unsigned>(s[l]);
    const unsigned scale=static_cast<unsigned>(power[r-l]);
    return static_cast<int>(hi-lo*scale);
}

// Length of the longest common suffix of prefix [1..u] of the sequence hashed
// in s[] and prefix [1..v] of the sequence hashed in t[], found by binary
// search over hash equality. The result lies in [0, min(u,v)].
int getlcp(int s[],int u,int t[],int v){
    const int limit=std::min(u,v);

    // Invariant: l <= answer < r.
    int l=0,r;
    // Shortcut: when the last 50 symbols already differ the answer is below 50.
    // The second slice must be taken from t[], not s[]; hashing s[] on both
    // sides gives a wrong bound whenever the two sequences differ.
    if(limit<=50||gethash(s,u-50,u)==gethash(t,v-50,v))r=limit+1;
    else r=50;

    while(r-l>1){
        const int mid=(l+r)>>1;
        if(gethash(s,u-mid,u)==gethash(t,v-mid,v))l=mid;
        else r=mid;
    }

    return l;
}

// Buffer for the current query pattern; main reads it at qs+1, so the pattern
// is 1-indexed and qs[0] is never written by the code in this file.
char qs[M];
// Buffer for the whole text; main reads at inis+1 and appends at inis+n+1, so
// the text is 1-indexed and inis[0] is never written by the code in this file.
char inis[N+5];
// Decoding key handed to gettrue() by in(); main XORs every query answer into it.
int mask;
// Decodes s in place: for each position in order, the key is advanced with
// key = (key*131 + position) % length and the characters at the position and
// at the new key are exchanged. `mask` is a by-value copy, so the caller's key
// is left untouched.
void gettrue(char s[],int mask){
    const int n=static_cast<int>(strlen(s));
    int pos=0;
    while(pos<n){
        mask=(mask*131+pos)%n;
        const char held=s[pos];
        s[pos]=s[mask];
        s[mask]=held;
        ++pos;
    }
}
void in(char s[],int pres[]){
    scanf("%s",s);
    gettrue(s,mask);
    //printf("%s\n",s);
    int len=strlen(s);
    for(int i=0;i<len;++i)pres[i]=pres[i-1]*base+s[i];
}

// Node ids of the two sentinel tree nodes; they sit just above the ids 1..N
// used for text positions. find()/query() always step right at smlinf and left
// at biginf, i.e. they act as the smallest and the largest key.
const int smlinf=N+1,biginf=N+2;
// Tree storage indexed by node id: ch[x][0]/ch[x][1] are the children, fa[x]
// the parent and size[x] the count maintained by pushup(). Id 0 means "no node".
// NOTE(review): a global named `size` combined with `using namespace std;` may
// be ambiguous with std::size when built as C++17 or later - confirm the
// language standard used to compile this file.
int ch[N+5][2],fa[N+5],size[N+5];
// Debug helper: prints the children, parent and size recorded for one node.
void out(int node){
    const int left=ch[node][0];
    const int right=ch[node][1];
    const int parent=fa[node];
    const int count=size[node];
    printf("%d={ch[0]=%d,ch[1]=%d,fa=%d,size=%d}\n",node,left,right,parent,count);
}
// Debug helper: prints the subtree rooted at node in pre-order
// (node first, then its left subtree, then its right subtree).
void outdfs(int node){
    out(node);
    for(int side=0;side<2;++side){
        const int child=ch[node][side];
        if(child!=0)outdfs(child);
    }
}
// Recomputes size[node] from its two children. The node itself contributes 1
// only when its id is in 1..N, so id 0 and the two sentinels count as 0.
void pushup(int node){
    const int self=(node!=0&&node<=N)?1:0;
    const int left=ch[node][0];
    const int right=ch[node][1];
    size[node]=size[left]+size[right]+self;
}
// Rotates node above its parent. Only the demoted parent has its size
// refreshed here; node's own size is left for the caller (splay() refreshes
// it once at the end). When a link is 0 the writes land on slot 0 of fa/ch.
void rot(int node){
    const int parent=fa[node];
    const int grand=fa[parent];
    const int side=(ch[parent][1]==node)?1:0;   // which child of parent node is
    const int inner=ch[node][side^1];           // subtree that changes owner

    // Parent links.
    fa[node]=grand;
    fa[parent]=node;
    fa[inner]=parent;

    // Child links.
    ch[parent][side]=inner;
    ch[node][side^1]=parent;
    const int slot=(ch[grand][0]==parent)?0:1;
    ch[grand][slot]=node;

    pushup(parent);
}
// Cursor node shared by the tree routines: find()/query() walk it down the
// tree, and splay() lifts whatever node it currently designates.
int root;
// Rotates the node held in the global `root` upwards until its parent is aim
// (aim==0 lifts it to the top of the tree), then refreshes its size.
void splay(int aim){
    while(fa[root]!=aim){
        const int parent=fa[root];
        if(fa[parent]!=aim){
            // Same-side chain: rotate the parent first; otherwise rotate the node.
            const bool parentIsLeft=(ch[fa[parent]][0]==parent);
            const bool selfIsLeft=(ch[parent][0]==root);
            if(parentIsLeft==selfIsLeft)rot(parent);
            else rot(root);
        }
        rot(root);
    }
    pushup(root);
}
void find(int node){
    //printf("find(%d)\n",node);

    for(int lcp;;)
        if(root==smlinf)
            if(ch[root][1])root=ch[root][1];
            else{
                ch[root][1]=node;
                break;
            } 
        else
            if(root==biginf)
                if(ch[root][0])root=ch[root][0];
                else{
                    ch[root][0]=node;
                    break;
                } 
            else{
                lcp=getlcp(pres,root,pres,node);
                //cout<<"Getlcp("<<root<<","<<node<<")="<<lcp<<endl;
                if(inis[root-lcp]<inis[node-lcp]){
                    //cout<<root<<"<"<<node<<endl;
                    if(ch[root][1])root=ch[root][1];
                    else{
                        ch[root][1]=node;
                        break;
                    }
                }
                else{
                    //cout<<root<<">"<<node<<endl;
                    if(ch[root][0])root=ch[root][0];
                    else{
                        ch[root][0]=node;
                        break;
                    }
                }
            }
    //cout<<"Add at "<<root<<endl;
    fa[node]=root;
    root=node;
    splay(0);
}
int query(){
    int qlen=strlen(qs+1);
    int lans=smlinf,rans=biginf;
    int lcp;

    for(;;)
        if(root==smlinf)
            if(ch[root][1])root=ch[root][1];
            else break;
        else
            if(root==biginf)
                if(ch[root][0])root=ch[root][0];
                else break;
            else
                if(gethash(pres,root-qlen,root)==gethash(qpres,0,qlen))
                    if(ch[root][0])root=ch[root][0];
                    else break;
                else{
                    lcp=getlcp(pres,root,qpres,qlen);
                    //cout<<"Getlcp("<<root<<",qs)="<<lcp<<endl;
                    if(inis[root-lcp]<qs[qlen-lcp]){
                        lans=root;
                        if(ch[root][1])root=ch[root][1];
                        else break;
                    }
                    else
                        if(ch[root][0])root=ch[root][0];
                        else break;
                }
    splay(0);

    for(;;)
        if(root==smlinf)
            if(ch[root][1])root=ch[root][1];
            else break;
        else
            if(root==biginf)
                if(ch[root][0])root=ch[root][0];
                else break;
            else
                if(gethash(pres,root-qlen,root)==gethash(qpres,0,qlen))
                    if(ch[root][1])root=ch[root][1];
                    else break;
                else{
                    //cout<<root<<":"<<gethash(pres,root-qlen,root)<<","<<gethash(qpres,0,qlen)<<endl;
                    lcp=getlcp(pres,root,qpres,qlen);
                    //cout<<"Getlcp("<<root<<",qs)="<<lcp<<endl;
                    if(inis[root-lcp]<qs[qlen-lcp]){
                        //cout<<root<<"<\n";
                        if(ch[root][1])root=ch[root][1];
                        else break;
                    }
                    else{
                        //cout<<root<<">\n";
                        rans=root;
                        if(ch[root][0])root=ch[root][0];
                        else break;
                    }
                }
    splay(0);

    root=lans;
    splay(0);

    root=rans;
    splay(lans);
    root=lans;
    pushup(root);

    //cout<<"lans="<<lans<<",rans="<<rans<<endl;

    return size[ch[ch[root][1]][0]];
}
// Driver: redirects stdin/stdout to the judge's files, precomputes the hash
// powers, reads the operation count and the initial text, inserts every text
// position into the tree, then processes the operations. An operation whose
// name starts with 'A' appends a decoded string to the text; any other
// operation is a query whose answer is printed and XORed into `mask`.
int main(){
    freopen("bzoj_2555.in","r",stdin);
    freopen("bzoj_2555.out","w",stdout);

    power[0]=1;
    for(int i=1;i<=N;++i)power[i]=power[i-1]*base;

    // Every text-position node starts out as a single-node subtree.
    for(int i=N;i;--i)size[i]=1;

    // Stop cleanly on missing or malformed input instead of running the loop
    // below with an uninitialised counter.
    int q=0;
    if(scanf("%d",&q)!=1)return 0;

    if(scanf("%s",inis+1)!=1)return 0;
    int n=strlen(inis+1);
    for(int i=1;i<=n;++i)pres[i]=pres[i-1]*base+inis[i];

    // Initial tree: the two sentinels only.
    ch[smlinf][1]=biginf,fa[biginf]=smlinf;
    root=smlinf;
    for(int i=1;i<=n;++i)find(i);   // n>=1 here because the %s read succeeded

    char type[10];
    int ans;
    while(q--){
        // Width-limited read: type[] holds at most 9 characters plus the NUL.
        if(scanf("%9s",type)!=1)break;
        if(type[0]=='A'){
            // Append: decode the new piece, extend the prefix hashes, then
            // insert each new text position.
            in(inis+n+1,pres+n+1);
            int len=strlen(inis+n+1);
            for(int i=1;i<=len;++i)find(n+i);
            n+=len;
        }
        else{
            in(qs+1,qpres+1);
            ans=query();
            printf("%d\n",ans);
            mask^=ans;
        }
    }
    return 0;
}

你可能感兴趣的:(hash,SA,平衡树)