回文树网上比较经典的模板为:Palindromic Tree——回文树【处理一类回文串问题的强力工具】,个人见过的也多为这个,网上还有一个邻接表的模板,较为省空间。原地址为回文树(附模板题URAL-1960)。回文树的有关视频讲解成电在B站的算法讲堂也有。
例题Ural1960:题目链接:http://acm.timus.ru/problem.aspx?space=1&num=1960
题意:给定一个字符串,求每次插入一个字符输出当前的不同本质子串。
题解:毕竟裸的回文树。每次创建一个新节点就说明多了一个不同的字符串,否则就不会新建结点,考虑一开始建立了0和-1的结点,所以每次add后的p-2就是所求。
AC代码:
/*
* @Author: 王文宇
* @Date: 2018-09-03 16:43:50
* @Last Modified by: 王文宇
* @Last Modified time: 2018-09-03 21:26:55
*/
#include
using namespace std;
const int maxn = 1e5+7;
const int siz = 26;
#define _for(i,a,b) for(int i=a;i<=b;i++)
struct PAM
{
int next[maxn][siz];
int fail[maxn];
int num[maxn];
int len[maxn];
int cnt[maxn];
int n,p,last,S[maxn];
int sum;
int newnode(int x)
{
_for(i,0,siz)next[p][i]=0;
len[p]=x;
num[p]=0;
cnt[p]=0;
return p++;
}
void init()
{
sum=0;
p=0;
newnode(0);
newnode(-1);
last=n=0;
fail[0]=1;
S[0]=-1;
}
int get_fail(int x)
{
while(S[n-len[x]-1]!=S[n])x=fail[x];
return x;
}
void add(int x)
{
int c = x-'a';
S[++n]=c;
int cur = get_fail(last);
if(!next[cur][c])
{
int now = newnode(len[cur]+2);
fail[now]=next[get_fail(fail[cur])][c];
next[cur][c]=now;
num[now]=num[fail[now]]+1;
sum++;
}
last = next[cur][c];
cnt[last]++;
}
void count()
{
for(int i=p-1;i>=0;i--)cnt[fail[i]]+=cnt[i];
}
}pam;
int main(int argc, char const *argv[])
{
char s[maxn];
cin>>s;
int l = strlen(s)-1;
pam.init();
_for(i,0,l)
{
pam.add(s[i]);
cout<
BZOJ2565: 最长双回文串:链接:https://www.lydsy.com/JudgeOnline/problem.php?id=2565
题意:求一个最长子串,使得子串有一个位置使得位置前和位置后的子串都是回文串。
题解:正向建回文树,建的时候对于当前结点i,从1-i中的最长后缀为len[last],记录下来然后反向再统计一遍取最大值即可。
AC代码:
/*
* @Author: 王文宇
* @Date: 2018-09-03 16:43:50
* @Last Modified by: 王文宇
* @Last Modified time: 2018-09-04 00:47:55
*/
#include
using namespace std;
const int maxn = 1e5+7;
const int siz = 26;
#define _for(i,a,b) for(int i=a;i<=b;i++)
int sum[maxn];
struct PAM
{
int next[maxn][siz];
int fail[maxn];
int num[maxn];
int len[maxn];
int cnt[maxn];
int n,p,last,S[maxn];
int sum;
int newnode(int x)
{
_for(i,0,siz)next[p][i]=0;
len[p]=x;
num[p]=0;
cnt[p]=0;
return p++;
}
void init()
{
sum=0;
p=0;
newnode(0);
newnode(-1);
last=n=0;
fail[0]=1;
S[0]=-1;
}
int get_fail(int x)
{
while(S[n-len[x]-1]!=S[n])x=fail[x];
return x;
}
int add(int x)
{
int c = x-'a';
S[++n]=c;
int cur = get_fail(last);
if(!next[cur][c])
{
int now = newnode(len[cur]+2);
fail[now]=next[get_fail(fail[cur])][c];
next[cur][c]=now;
num[now]=num[fail[now]]+1;
sum++;
}
last = next[cur][c];
cnt[last]++;
return len[last];
}
void count()
{
for(int i=p-1;i>=0;i--)cnt[fail[i]]+=cnt[i];
}
}pam;
int main(int argc, char const *argv[])
{
char s[maxn];
cin>>s;
int l = strlen(s)-1;
pam.init();
_for(i,0,l)
{
sum[i]=pam.add(s[i]);
}
pam.init();
s[l+1]='z'+1;
int ans = 0;
int k = 1;
_for(i,0,l-1)
{
ans = max(ans,sum[l-i-1]+pam.add(s[l-i]));
}
ans = max(ans,pam.add(s[0]));
cout<
BZOJ2160: 拉拉队排练:链接:https://www.lydsy.com/JudgeOnline/problem.php?id=2160
题意:给定n个字符组成的字符串和一个整数k,字符串中每一个奇数回文子串为一个集合,按子串长度排序,求前k个子串的个数的乘积为多少。
题解:把每一个长度为奇数的回文子串找出来排序,然后快速幂求解即可。主要爆int。
AC代码:
/*
* @Author: 王文宇
* @Date: 2018-09-03 16:43:50
* @Last Modified by: 王文宇
* @Last Modified time: 2018-09-04 02:32:49
*/
#include
using namespace std;
const int maxn = 1e6+7;
const int siz = 26;
const int mod = 19930726;
typedef long long ll;
#define _for(i,a,b) for(int i=a;i<=b;i++)
struct PAM
{
int next[maxn][siz];
int fail[maxn];
int num[maxn];
int len[maxn];
int cnt[maxn];
int n,p,last,S[maxn];
int sum;
int newnode(int x)
{
_for(i,0,siz)next[p][i]=0;
len[p]=x;
num[p]=0;
cnt[p]=0;
return p++;
}
void init()
{
sum=0;
p=0;
newnode(0);
newnode(-1);
last=n=0;
fail[0]=1;
S[0]=-1;
}
int get_fail(int x)
{
while(S[n-len[x]-1]!=S[n])x=fail[x];
return x;
}
int add(int x)
{
int c = x-'a';
S[++n]=c;
int cur = get_fail(last);
if(!next[cur][c])
{
int now = newnode(len[cur]+2);
fail[now]=next[get_fail(fail[cur])][c];
next[cur][c]=now;
num[now]=num[fail[now]]+1;
sum++;
}
last = next[cur][c];
cnt[last]++;
return len[last];
}
void count()
{
for(int i=p-1;i>=0;i--)cnt[fail[i]]+=cnt[i];
}
}pam;
ll Pow(int a,int b)
{
ll ans = 1;
while(b)
{
if(b&1)ans = ans*a%mod;
a = 1LL*a*a%mod;
b>>=1;
}
return ans;
}
int main(int argc, char const *argv[])
{
int n;
ll k;
cin>>n>>k;
char s[maxn];
cin>>s;
vector >E;
pam.init();
_for(i,0,n-1)
{
pam.add(s[i]);
}
pam.count();
for(int i=2;i<=pam.p;i++)
{
if(pam.len[i]%2==1)E.push_back(make_pair(pam.len[i],i));
}
sort(E.begin(),E.end());
reverse(E.begin(),E.end());
int l = E.size();
ll now = 0;
int ok = 0;
ll ans = 1;
_for(i,0,l-1)
{
//cout<