一个字符串本质不同的回文子串的个数是 O(n) 级别的。
回文树上每个节点代表一个回文串。如果节点 a 两边加同一个字符 c 能到达 b,那么 ch[a][c]=b。
定义 fail[u] 表示 u 节点代表回文串的最长回文后缀。
定义 l[u] 表示 u 代表回文串的长度。
这样,回文树有两个“根”,一个根的子树内的回文串长度全是奇数,另一个子树全是偶数。fail 在两棵树上随便连,偶根的 fail 指向奇根。
增量构造,每次添加一个字符 s[i]=c,新出现的回文串只能是新字符串的最长回文后缀。记 last 表示s[1,…,i-1] 的最长回文后缀,沿着 fail 跳 last,直到某个 p 点左边字符等于 s[i],如果 p 没有 c 这个出边,ch[p][c]=++tot。类似的,想要得到新建点的 fail,继续沿着 p 的 fail 跳,直到某个 q 点左边字符等于 s[i],fail[tot]=q。
注意就是把奇根设为 1 偶根设为 0,这样如果一个点的 ch 为空就自动到了 0 节点。
回文树的父亲节点一定比子节点标号小。因此直接按标号大小就可以dfs整棵树。
int Insert(int x,int c)
{
int p=lst;
while(s[x-l[p]-1]^c) p=fail[p];
if(!ch[p][c])
{
int now=++tot,q=fail[p];
l[now]=l[p]+2;
while(s[x-l[q]-1]^c) q=fail[q];
fail[now]=ch[q][c];
ch[p][c]=now;
}
return lst=ch[p][c];
}
int main()
{
...
s[0]=-1,tot=1,l[1]=-1,fa[0]=1; //注意初始化
for(int i=1;i<=n;i++) pos[i]=Insert(i,s[i]);
...
}
统计每个节点对应串出现次数。考虑每次加一个字符会对哪些串有贡献,显然是那些在沿着fail的点。因此只需要沿着fail加上来就好了。
CF932G
给定一个字符串 s,划分成 p 1 , . . . , p 2 k p_1,...,p_{2k} p1,...,p2k,满足 p i = p 2 k − i + 1 p_i=p_{2k-i+1} pi=p2k−i+1,求划分方案数。
先挂一篇题解:https://blog.csdn.net/l_0_forever_lf/article/details/79494755
我们构造一个新的字符串: s 1 s n s 2 s n − 1 . . . s_1s_ns_2s_{n-1}... s1sns2sn−1...,发现原问题等价于把新字符串划分成若干个回文串,并且每个回文串长度都是偶数的方案数。显然可以 DP,发现我们需要枚举以 i 结尾的所有回文串,根据回文串的相关知识,我们知道这等价于 i 这个前缀最长回文串的 border。而我们又知道 border 形成 log 段等差数列,容易想到每一段一起转移。体现在回文自动机上就是沿着 fa 跳会形成公差为负的等差数列。我们需要在回文自动机上维护 x 和 fa[x] 的长度差,和 nxt 表示下一段等差数列的第一项。g 维护从 x 出发沿 fa 跳的这段等差数列最后一次出现的 f 的和。
#include
#define ll long long
#define pb push_back
#define fir first
#define sec second
#define ld long double
using namespace std;
const int N=1000010,mod=1e9+7;
typedef pair P;
int ch[N][26],fa[N],g[N],f[N],tot,lst,nxt[N],l[N],s[N],dif[N],pos[N];
char t[N];
int read()
{
int x=0;char c=getchar(),flag='+';
while(!isdigit(c)) flag=c,c=getchar();
while(isdigit(c)) x=x*10+c-'0',c=getchar();
return flag=='-'?-x:x;
}
int Insert(int x,int c)
{
int p=lst;
while(s[x-l[p]-1]^c) p=fa[p];
if(!ch[p][c])
{
int now=++tot,q=fa[p];
l[now]=l[p]+2;
while(s[x-l[q]-1]^c) q=fa[q];
int t=ch[q][c];
dif[now]=l[now]-l[t];
fa[now]=t;
if(dif[now]==dif[t]) nxt[now]=nxt[t];
else nxt[now]=t;
ch[p][c]=now;
}
return lst=ch[p][c];
}
int main()
{
scanf("%s",t+1);
int n=strlen(t+1);
for(int i=1;i<=n;i++)
{
if(i&1) s[i]=t[i+1>>1]-'a';
else s[i]=t[n-i/2+1]-'a';
}
s[0]=-1;
tot=1,l[1]=-1,fa[0]=1;
for(int i=1;i<=n;i++) pos[i]=Insert(i,s[i]);
f[0]=1;
for(int i=1;i<=n;i++)
{
for(int x=pos[i];x;x=nxt[x])
{
g[x]=f[i-l[nxt[x]]-dif[x]];
if(dif[x]==dif[fa[x]]) g[x]=g[x]+g[fa[x]]>=mod?g[x]+g[fa[x]]-mod:g[x]+g[fa[x]];
if(!(i&1)) f[i]=f[i]+g[x]>=mod?f[i]+g[x]-mod:f[i]+g[x];
}
}
cout<
722 F