在字符串 p 上定义函数 f(p) ,假设字符串的长度是 l ,那么
1≤n,q≤105,1≤kth≤n(n+1)2
既然题目询问的是回文子串的函数值,那么我们考虑构出回文树,这样每一个回文子串的 f 值我们就可以很方便地计算出来。
现在的问题是我怎么找到排名为 kth 的回文子串呢?因为一个字符串本质不同的回文子串只有 O(n) 个,那么我可以把它们拉出来排序,然后对出现次数做前缀和再二分查找来得到这个串。因此现在的问题变为我如何排序这 O(n) 个回文子串?
如果你使用回文平衡树那这个操作就是模板了。有没有其他方法呢?我们直接快排,比较两个子串的大小关系可以通过求出它们的 LCP 长度来实现,如果你对这个串做一遍后缀数组然后 RMQ ,那你的复杂度就可以做到 O(nlogn) ,但是我比较菜,于是就写了个二分加哈希,时间复杂度 O(nlog2n) 。
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cctype>
using namespace std;
typedef long long LL;
template <typename T>
void read(T &x)
{
x=0;
int f=1;char ch=getchar();
while (!isdigit(ch)) f=ch=='-'?-1:f,ch=getchar();
while (isdigit(ch)) x=x*10+ch-'0',ch=getchar();
x*=f;
}
int buf[30];
void write(int x)
{
if (x<0) putchar('-'),x=-x;
for (;x;x/=10) buf[++buf[0]]=x%10;
if (!buf[0]) buf[++buf[0]]=0;
for (;buf[0];putchar('0'+buf[buf[0]--]));
}
const int A=100001;
const int P=1000000007;
const int N=100050;
const int C=26;
const int MOD=998244353;
const int P1=67;
const int P2=89;
typedef pair<int,int> PI;
#define mkp(a,b) make_pair(a,b)
#define ft first
#define sd second
PI operator+(PI x,PI y){return mkp((x.ft+y.ft)%MOD,(x.sd+y.sd)%MOD);}
PI operator-(PI x,PI y){return mkp((x.ft-y.ft+MOD)%MOD,(x.sd-y.sd+MOD)%MOD);}
PI operator*(PI x,PI y){return mkp(1ll*x.ft*y.ft%MOD,1ll*x.sd*y.sd%MOD);}
PI POW[N],IPOW[N],p,ip,preh[N];
int node[N],pw[N],srt[N];
LL sum[N];
char s[N];
int n,q,cnt;
int quick_power(int x,int y)
{
int ret=1;
for (;y;y>>=1,x=1ll*x*x%MOD) if (y&1) ret=1ll*ret*x%MOD;
return ret;
}
struct Palindrome_tree
{
int fail[N],len[N],hash[N],st[N];
int nxt[N][C];
int tot,suf;
LL size[N];
int newnode()
{
fail[++tot]=0;
for (int c=0;c<C;++c) nxt[tot][c]=0;
size[tot]=0;
return tot;
}
void init()
{
tot=-1,len[newnode()]=0,len[newnode()]=-1;
fail[0]=fail[1]=1,suf=1;
}
int getfail(int x,int pos){return s[pos-len[x]-1]==s[pos]?x:getfail(fail[x],pos);}
void insert(int pos)
{
int p=getfail(suf,pos),c=s[pos]-'a';
if (!nxt[p][c])
{
int np=newnode();
len[np]=len[p]+2,fail[np]=nxt[getfail(fail[p],pos)][c],nxt[p][c]=np,st[np]=pos-len[np]+1;
hash[np]=(1ll*hash[p]*A+1ll*s[pos]*(1+(p!=1)*pw[len[np]-1]))%P;
}
++size[suf=nxt[p][c]];
}
void calc(){for (int x=tot;x>=1;--x) if (x!=1) size[fail[x]]+=size[x];}
}pam;
void pre()
{
pw[0]=1;
for (int i=1;i<=n;++i) pw[i]=1ll*pw[i-1]*A%P;
POW[0]=mkp(1,1),p=mkp(P1,P2);
for (int i=1;i<=n;++i) POW[i]=POW[i-1]*p;
IPOW[0]=mkp(1,1),ip=mkp(quick_power(P1,MOD-2),quick_power(P2,MOD-2));
for (int i=1;i<=n;++i) IPOW[i]=IPOW[i-1]*ip;
for (int i=1;i<=n;++i) preh[i]=preh[i-1]+(POW[i-1]*mkp(s[i]-'a',s[i]-'a'));
}
PI gethash(int x,int l){return (preh[x+l-1]-preh[x-1])*IPOW[x-1];}
int LCP(int x,int y,int len)
{
int ret=0,l=1,r=len;
for (int mid;l<=r;)
{
mid=l+r>>1;
if (gethash(x,mid)==gethash(y,mid)) l=(ret=mid)+1;
else r=mid-1;
}
return ret;
}
bool cmp(int x,int y)
{
int len=min(pam.len[x],pam.len[y]),lcp;
lcp=LCP(pam.st[x],pam.st[y],len);
return !(lcp==pam.len[y])&&(lcp==pam.len[x]||s[pam.st[x]+lcp]<s[pam.st[y]+lcp]);
}
int main()
{
freopen("fpalindrome.in","r",stdin),freopen("fpalindrome.out","w",stdout);
read(n),read(q);
scanf("%s",s+1),s[0]='#',pre();
pam.init();
for (int i=1;i<=n;++i) pam.insert(i),node[i]=pam.suf;
pam.calc();
for (int i=2;i<=pam.tot;++i) srt[++cnt]=i;
sort(srt+1,srt+cnt+1,cmp);
sum[0]=0;
for (int i=1;i<=cnt;++i) sum[i]=sum[i-1]+pam.size[srt[i]];
for (LL x;q--;putchar('\n'))
{
read(x);
if (x>sum[cnt]) write(-1);
else
{
int ret=0,l=1,r=cnt;
for (int mid;l<=r;)
{
mid=l+r>>1;
if (sum[mid]<x) l=(ret=mid)+1;
else r=mid-1;
}
write(pam.hash[srt[++ret]]);
}
}
fclose(stdin),fclose(stdout);
return 0;
}