[HackerRank-World CodeSprint 6]Functional Palindromes

题目大意

在字符串 p 上定义函数 f(p) ,假设字符串的长度是 l ,那么

f(p)=i=1lpiali

其中 pi 表示第 i 个字符的ASCII码。
现在给定一个长度为 n 的字符串 s ,有 q 个询问,每次询问字典序排名为 kth 的回文子串的 f 函数值。
两个本质相同,起始位置不同的回文子串视作两个不同的串。

1n,q105,1kthn(n+1)2

题目分析

既然题目询问的是回文子串的函数值,那么我们考虑构出回文树,这样每一个回文子串的 f 值我们就可以很方便地计算出来。
现在的问题是我怎么找到排名为 kth 的回文子串呢?因为一个字符串本质不同的回文子串只有 O(n) 个,那么我可以把它们拉出来排序,然后对出现次数做前缀和再二分查找来得到这个串。因此现在的问题变为我如何排序这 O(n) 个回文子串?
如果你使用回文平衡树那这个操作就是模板了。有没有其他方法呢?我们直接快排,比较两个子串的大小关系可以通过求出它们的 LCP 长度来实现,如果你对这个串做一遍后缀数组然后 RMQ ,那你的复杂度就可以做到 O(nlogn) ,但是我比较菜,于是就写了个二分加哈希,时间复杂度 O(nlog2n)

代码实现

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cctype>

using namespace std;

typedef long long LL;

template <typename T>
void read(T &x)
{
    x=0;
    int f=1;char ch=getchar();
    while (!isdigit(ch)) f=ch=='-'?-1:f,ch=getchar();
    while (isdigit(ch)) x=x*10+ch-'0',ch=getchar();
    x*=f;
}

int buf[30];

void write(int x)
{
    if (x<0) putchar('-'),x=-x;
    for (;x;x/=10) buf[++buf[0]]=x%10;
    if (!buf[0]) buf[++buf[0]]=0;
    for (;buf[0];putchar('0'+buf[buf[0]--]));
}

const int A=100001;
const int P=1000000007;
const int N=100050;
const int C=26;
const int MOD=998244353;
const int P1=67;
const int P2=89;

typedef pair<int,int> PI;
#define mkp(a,b) make_pair(a,b)
#define ft first
#define sd second

PI operator+(PI x,PI y){return mkp((x.ft+y.ft)%MOD,(x.sd+y.sd)%MOD);}
PI operator-(PI x,PI y){return mkp((x.ft-y.ft+MOD)%MOD,(x.sd-y.sd+MOD)%MOD);}
PI operator*(PI x,PI y){return mkp(1ll*x.ft*y.ft%MOD,1ll*x.sd*y.sd%MOD);}

PI POW[N],IPOW[N],p,ip,preh[N];
int node[N],pw[N],srt[N];
LL sum[N];
char s[N];
int n,q,cnt;

int quick_power(int x,int y)
{
    int ret=1;
    for (;y;y>>=1,x=1ll*x*x%MOD) if (y&1) ret=1ll*ret*x%MOD;
    return ret;
}

struct Palindrome_tree
{
    int fail[N],len[N],hash[N],st[N];
    int nxt[N][C];
    int tot,suf;
    LL size[N];

    int newnode()
    {
        fail[++tot]=0;
        for (int c=0;c<C;++c) nxt[tot][c]=0;
        size[tot]=0;
        return tot;
    }

    void init()
    {
        tot=-1,len[newnode()]=0,len[newnode()]=-1;
        fail[0]=fail[1]=1,suf=1;
    }

    int getfail(int x,int pos){return s[pos-len[x]-1]==s[pos]?x:getfail(fail[x],pos);}

    void insert(int pos)
    {
        int p=getfail(suf,pos),c=s[pos]-'a';
        if (!nxt[p][c])
        {
            int np=newnode();
            len[np]=len[p]+2,fail[np]=nxt[getfail(fail[p],pos)][c],nxt[p][c]=np,st[np]=pos-len[np]+1;
            hash[np]=(1ll*hash[p]*A+1ll*s[pos]*(1+(p!=1)*pw[len[np]-1]))%P;
        }
        ++size[suf=nxt[p][c]];
    }

    void calc(){for (int x=tot;x>=1;--x) if (x!=1) size[fail[x]]+=size[x];}
}pam;

void pre()
{
    pw[0]=1;
    for (int i=1;i<=n;++i) pw[i]=1ll*pw[i-1]*A%P;
    POW[0]=mkp(1,1),p=mkp(P1,P2);
    for (int i=1;i<=n;++i) POW[i]=POW[i-1]*p;
    IPOW[0]=mkp(1,1),ip=mkp(quick_power(P1,MOD-2),quick_power(P2,MOD-2));
    for (int i=1;i<=n;++i) IPOW[i]=IPOW[i-1]*ip;
    for (int i=1;i<=n;++i) preh[i]=preh[i-1]+(POW[i-1]*mkp(s[i]-'a',s[i]-'a'));
}

PI gethash(int x,int l){return (preh[x+l-1]-preh[x-1])*IPOW[x-1];}

int LCP(int x,int y,int len)
{
    int ret=0,l=1,r=len;
    for (int mid;l<=r;)
    {
        mid=l+r>>1;
        if (gethash(x,mid)==gethash(y,mid)) l=(ret=mid)+1;
        else r=mid-1;
    }
    return ret;
}

bool cmp(int x,int y)
{
    int len=min(pam.len[x],pam.len[y]),lcp;
    lcp=LCP(pam.st[x],pam.st[y],len);
    return !(lcp==pam.len[y])&&(lcp==pam.len[x]||s[pam.st[x]+lcp]<s[pam.st[y]+lcp]);
}

int main()
{
    freopen("fpalindrome.in","r",stdin),freopen("fpalindrome.out","w",stdout);
    read(n),read(q);
    scanf("%s",s+1),s[0]='#',pre();
    pam.init();
    for (int i=1;i<=n;++i) pam.insert(i),node[i]=pam.suf;
    pam.calc();
    for (int i=2;i<=pam.tot;++i) srt[++cnt]=i;
    sort(srt+1,srt+cnt+1,cmp);
    sum[0]=0;
    for (int i=1;i<=cnt;++i) sum[i]=sum[i-1]+pam.size[srt[i]];
    for (LL x;q--;putchar('\n'))
    {
        read(x);
        if (x>sum[cnt]) write(-1);
        else
        {
            int ret=0,l=1,r=cnt;
            for (int mid;l<=r;)
            {
                mid=l+r>>1;
                if (sum[mid]<x) l=(ret=mid)+1;
                else r=mid-1;
            }
            write(pam.hash[srt[++ret]]);
        }
    }
    fclose(stdin),fclose(stdout);
    return 0;
}

你可能感兴趣的:(哈希,后缀数组,OI,回文树,hackerrank)