hihocoder1457【后缀自动机+拓扑排序DP】

题目链接

http://hihocoder.com/problemset/problem/1457

题解

所有乐章,用#隔开拼接起来建立sam。在自动机上按照拓扑序递推求出初始状态到当前点的不含#的路径数,然后用这个递推出初始到当前点所有子串的和。递推式见代码。

注意

我开始直接入度和递推都忽略自动机的#边。但这是不对的。因为这样忽略会导致有的点入度永远不是0,而不被算到。因为去掉#边可能导致某个入度和初始状态失去连接路径。所以拓扑就老老实实算上所有边弄,dp在忽略#边才对。

AC代码

#include
#include
#include
#include
using namespace std;
const int modd=1e9+7;
const int NN=2e6+20;
const int alph=11;
struct NODE
{
    int ch[alph];
    int len,fa;
    NODE(){memset(ch,0,sizeof(ch));len=0;fa=0;}
}dian[NN<<1];
//int is_pre[NN<<1];//标记是否是前缀
int las=1,tot=1;
void add(int c)
{
    //printf("%d\n",c);
    int p=las;int np=las=++tot;
    //is_pre[np]=1;//endpos[np]={i} where s[i]=c
    dian[np].len=dian[p].len+1;
    for(;p&&!dian[p].ch[c];p=dian[p].fa)dian[p].ch[c]=np;
    if(!p){dian[np].fa=1;}//以上为case 1
    else
    {
        int q=dian[p].ch[c];
        if(dian[q].len==dian[p].len+1)dian[np].fa=q;//以上为case 2
        else
        {
            int nq=++tot;dian[nq]=dian[q];//endpos不是i了
            dian[nq].len=dian[p].len+1;
            dian[q].fa=dian[np].fa=nq; 
            for(;p&&dian[p].ch[c]==q;p=dian[p].fa)dian[p].ch[c]=nq;//以上为case 3
        }
    }
}
int n;
//vectorparentcon[NN<<1];
char s[NN];
int ind[NN<<1];
void sam(){
    las=1,tot=1;
    n=strlen(s+1);
    for(int i=1;i<(n<<1);i++){
        //is_pre[i]=0;
        //parentcon[i].clear();
        ind[i]=0;
    }
    for(int i=1;i<=n;i++){
        if(s[i]=='#')add(10);
        else add(s[i]-'0');
    }
    for(int i=1;i<=tot;i++){
        //parentcon[dian[i].fa].push_back(i);
        for(int j=0;j<11;j++){
            int nex=dian[i].ch[j];
            if(nex)ind[nex]++;
        }
    }
}
long long numkind[NN<<1];
long long dp[NN<<1];
queue<int > q;
void get_dp(){
    while(!q.empty())q.pop();
    for(int i=1;i<=tot;i++){numkind[i]=0;dp[i]=0;}
    q.push(1);numkind[1]=1;
    while(!q.empty()){
        int cur=q.front();
        q.pop();
        for(int i=0;i<11;i++){//哭死
            int nex=dian[cur].ch[i];
            if(nex){
                ind[nex]--;
                if(i!=10){
                    numkind[nex]+=numkind[cur];
                    numkind[nex]%=modd;//if(nex==6)printf("*%d %lld\n",cur,numkind[6]);
                    dp[nex]=(dp[nex]+(dp[cur]*10ll+1ll*i*numkind[cur])) %modd;
                }
                if(ind[nex]==0)q.push(nex);
            }
        }
    }
}
char temp[NN>>1];
int main(){
    int t;
    scanf("%d",&t);
    int cnt=0;
    while(t--){
        scanf("%s",temp+1);
        int nn=strlen(temp+1);
        for(int i=1;i<=nn;i++){
            s[++cnt]=temp[i];
        }
        s[++cnt]='#';
    }
    sam();
    // for(int i=1;i<=tot;i++){
    //     printf("%d\n",dian[i].fa);
    //     for(int j=0;j<11;j++)printf("%d ",dian[i].ch[j]);
    //     printf("\n");
    // }
    get_dp();
    // for(int i=1;i<=tot;i++){
    //     printf("%lld\n",numkind[i]);
    // }
    long long ans=0;
    for(int i=1;i<=tot;i++){
        ans+=dp[i];
        ans%=modd;
    }
    printf("%lld\n",ans);
    return 0;
}

你可能感兴趣的:(sam,dp)