回文自动机模板

回文自动机每个节点代表了一个回文串

能求出来的有:

1 本质不同的回文串的个数(tot - 1)

2 每种回文串出现的次数(cnt数组)

3 每种回文串的长度 (len数组)

4 以当前节点为后缀的回文串个数 (sed数组)

5 每个回文串在原串出现的位置 (record数组)

#include 
using namespace std;
const int N = 3e5 + 10;
typedef unsigned long long ull;
typedef long long ll;
char s[N];
int record[N];
struct Palindrome_tree{
	int nxt[N][26];
	int fail[N]; // 当前节点最长回文后缀的节点
	int len[N]; // 当前节点表示的回文串的长度
	int cnt[N]; // 当前节点回文串的个数, 在getcnt后可得到全部
	int sed[N]; // 以当前节点为后缀的回文串的个数
	int tot; // 节点个数
	int last; // 上一个节点
	void init()
	{
		tot = 0;
		memset(fail, 0, sizeof fail);
		memset(cnt, 0, sizeof cnt);
		memset(sed, 0, sizeof sed);
		memset(len, 0, sizeof len);
		memset(nxt, 0, sizeof nxt);
	}
	void build()
	{
		len[0] = 0, len[1] = -1; // 0为偶数长度根, 1为奇数长度根
		tot = 1, last = 0;
		fail[0] = 1;
	}
	int getfail(char *s, int x, int n)
	{
		while (s[n - len[x] - 1] != s[n]) // 比较x节点回文串新建两端是否相等
			x = fail[x]; // 若不同, 再比较x后缀回文串两端
		return x;
	}
	void insert(char* s, int n)
	{
		for (int i = 0; i < n; i++)
		{
			int c = s[i] - 'a';
			int p = getfail(s, last, i);// 得到第i个字符可以加到哪个节点的两端形成回文串 
			if (!nxt[p][c])
			{
				tot++;
				len[tot] = len[p] + 2;  // 在p节点两端添加两个字符
				fail[tot] = nxt[getfail(s, fail[p], i)][c]; //tot点的后缀回文,可以由上一个节点的后缀回文尝试得到  
				sed[tot] = sed[fail[tot]] + 1; // 以当前节点为结尾的回文串个数
				nxt[p][c] = tot; // 新建节点
			}
			last = nxt[p][c]; // 当前节点成为上一个节点
			cnt[last]++; //当前节点回文串++
                        record[last] = i;
		}
	}
	void get_cnt()
	{
		for (int i = tot; i > 0; i--)
			cnt[fail[i]] += cnt[i];
		//fail[i] 的节点 为 i 节点的后缀回文串, 所以个数相加
	}
}pdt;
int main(){
	while (~scanf("%s", s))
	{
		pdt.init();
		pdt.build();
		pdt.insert(s, strlen(s));
		pdt.get_cnt();
	}
	return 0;
}

对于init,如果组数过多,可能会导致超时(例如 UVALive - 7041 The Problem to Slow Down You )

所以有一个newnode的版本,即每次新建节点时,再初始化信息。

struct Palindrome_tree
{
	int nxt[N][26];
	int fail[N];
	int len[N];
	int cnt[N];
	int sed[N];
	int tot;
	int last;
	void init()
	{
		tot = 0;
		memset(fail, 0, sizeof fail);
		memset(cnt, 0, sizeof cnt);
		memset(sed, 0, sizeof sed);
		memset(len, 0, sizeof len);
		memset(nxt, 0, sizeof nxt);
	}
	int newnode(int lenx)
	{
		for (int i = 0; i < 26; i++)
			nxt[tot][i] = 0;
		sed[tot] = cnt[tot] = 0;
		len[tot] = lenx;
		return tot;
	}
	void build()
	{
		tot = 0;
		newnode(0);
		tot = 1, last = 0;
		newnode(-1);
		fail[0] = 1;
	}
	int getfail(char* s, int x, int n)
	{
		while (s[n - len[x] - 1] != s[n])
			x = fail[x];
		return x;
	}
	void insert(char* s, int n)
	{
		for (int i = 1; i <= n; i++)
		{
			int c = s[i] - 'a';
			int p = getfail(s, last, i);
			if (!nxt[p][c])
			{
				tot++;
				newnode(len[p] + 2);
				fail[tot] = nxt[getfail(s, fail[p], i)][c];
				sed[tot] = sed[fail[tot]] + 1;
				nxt[p][c] = tot;
			}
			last = nxt[p][c];
			cnt[last]++;
		}
	}
	void get_cnt()
	{
		for (int i = tot; i > 0; i--)
			cnt[fail[i]] += cnt[i];
	}
}pdt;

 

 

你可能感兴趣的:(字符串-回文自动机)