BZOJ4598: [Sdoi2016]模式字符串

BZOJ4598

求树上满足某些条件的点对,首先就可以想到点分治。
然后又与什么字符串匹配有关。 KMP,AC 之类的好像不太好用。。那就哈希吧!
添加答案的时候有两种情况:
BZOJ4598: [Sdoi2016]模式字符串_第1张图片

那么就分别维护从上到下的链和从下到上的链。不是所有链都存的,仅当“从该点到当前根的一段是若干个模式串的前缀或者后缀”时才存。
发现当长度为 a 时,不仅 ma 可以更新答案,长度为 kma 的也可以。那这样岂不是每次更新都是 O(n/m) ?存入答案的时候,其实长度为 x x+km 是等价的。取个模再存就好了。

看起来好像不是很难。。但是蒟蒻表示:好多优化啊!! QAQ 。。还是要把细节都想清白了再打。。

【代码】

#include 
#include 
#include 
#include 
#include 
#include 
#define N 1000005
#define Mod 1000000007
#define INF 0x7fffffff
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const ull base=31;

ll read()
{
    ll x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-') f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
    return x*f;
}

int sum,T,n,m,cnt,rt,s1,s2;
ll ans;
int b[N<<1],p[N],nextedge[N<<1];
int sz[N],f[N],d[N];
int Cnt[N],ccnt[N],st[N],sst[N];
char s[N];bool Flag[N];
ull a[N],ha[N],Ha[N],P[N],hash[N],Hash[N],deep[N],Deep[N];

void Pre()
{
    P[0]=1;
    for(int i=1;i1]*base;
}

void Add(int x,int y){
    cnt++;
    b[cnt]=y;
    nextedge[cnt]=p[x];
    p[x]=cnt;
}

void Anode(int x,int y){
    Add(x,y);Add(y,x);
}

void Input_Init()
{
    n=read(),m=read();rt=ans=0,f[0]=INF;sum=n;
    cnt=0;for(int i=1;i<=n;i++) p[i]=Flag[i]=deep[i]=0;
    scanf("%s",s+1);
    for(int i=1;i<=n;i++) a[i]=s[i]-'A'+1;
    for(int i=1;iint x,y;
        x=read(),y=read();
        Anode(x,y);
    }
    scanf("%s",s+1);
    for(int i=1;i<=m;i++) ha[i]=s[i]-'A'+1,Ha[i]=s[m-i+1]-'A'+1;
    for(int i=1;i<=n;i++) hash[i]=hash[i-1]+ha[(i-1)%m+1]*P[i-1],Hash[i]=Hash[i-1]+P[i-1]*Ha[(i-1)%m+1];
}

void Get_Root(int x,int fa)
{
    f[x]=0;sz[x]=1;
    for(int i=p[x];i;i=nextedge[i])
    {
        int v=b[i];if(v==fa||Flag[v]) continue;
        Get_Root(v,x);
        sz[x]+=sz[v];
        f[x]=max(f[x],sz[v]);
    }
    f[x]=max(f[x],sum-sz[x]);
    rt=f[rt]>f[x]?x:rt;
}

void Get_deep(int x,int fa)
{
    if(Hash[d[x]]==deep[x]&&a[x]==Ha[1]) st[++s1]=x;
    if(hash[d[x]]==deep[x]&&a[x]==ha[1]) sst[++s2]=x;
    for(int i=p[x];i;i=nextedge[i])
    {
        int v=b[i];if(v==fa||Flag[v]) continue;
        deep[v]=deep[x]*base+a[v];
        d[v]=d[x]+1;
        Get_deep(v,x);
    }
}

void Calc(int x)
{
    for(int i=0;i<=m;i++) Cnt[i]=ccnt[i]=0;
    if(Ha[1]==a[x]) ccnt[1]=1;
    if(Ha[m]==a[x]) Cnt[1]=1;
    for(int i=p[x];i;i=nextedge[i])
    {
        int v=b[i];if(Flag[v]) continue;
        s1=s2=0;d[v]=1;deep[v]=a[v];
        Get_deep(v,x);
        for(int j=1;j<=s1;j++){
            int t=st[j],pos=m-d[t]%m;
            if(pos==0) pos=m;
            ans+=Cnt[pos];
        }
        for(int j=1;j<=s2;j++){
            int t=sst[j],pos=m-d[t]%m;
            if(pos==0) pos=m;
            ans+=ccnt[pos];
        }
        for(int j=1;j<=s1;j++) {
            int t=st[j];int pos=d[t]%m+1;
            if(a[x]==Ha[pos]) ccnt[pos]++;
        }
        for(int j=1;j<=s2;j++) {
            int t=sst[j];int pos=d[t]%m+1;
            if(a[x]==ha[pos]) Cnt[pos]++;
        }
    }
}

void Work(int x)
{
    Calc(x);Flag[x]=1;
    for(int i=p[x];i;i=nextedge[i])
    {
        int v=b[i];if(Flag[v]) continue;
        sum=sz[v];rt=0;if(sum<m) continue;
        Get_Root(v,0);
        Work(rt);
    }
}

int main()
{
    T=read();
    Pre();
    while(T--)
    {
        Input_Init();
        Get_Root(1,0);
        Work(rt);
        printf("%lld\n",ans);
    }
    return 0;
}

你可能感兴趣的:(点分治,哈希)