[NOI2016]优秀的拆分 后缀自动机 树上启发式合并 线段树

[NOI2016]优秀的拆分

题目传送门
luogu
bzoj

分析

这道题不管采用Hash,后缀数组还是自动机,网上大部分的题解都采用了关键点+调和级数这个操作。本蒟蒻想不到关键点这个操作,所以采用的是一种较为繁琐的 O ( n l o g 2 ) O(nlog^2) O(nlog2)做法。

首先肯定将问题转化成对于每个 i i i求以 i i i为边界的 A A AA AA结构个数,当然前缀后缀分别求一遍,以下默认是前缀。

考虑形式化这个问题,对于某个前缀 i i i,求所有的前缀 j ( j < i ) j(j<i) j(j<i),使得 i , j i,j i,j的最长公共后缀的长度大于 j − i j-i ji,也就是 ∣ { j ∣ j < i , i − j ≤ ∣ L c s ( S 1 , j , S 1 , i ) ∣ } ∣ |\{j|j<i,i-j\le |Lcs(S_{1,j},S_{1,i})|\}| {jj<i,ijLcs(S1,j,S1,i)}

对于后缀前缀的问题,我们一般将他们放到后缀自动机的 p a r e n t parent parent树上考虑,由于后缀自动机的 p a r e n t parent parent树相当于是将每个前缀逆序插入 T r i e Trie Trie,所以某两个前缀的 L c s Lcs Lcs对应的就是他们 p a r e n t parent parent树上的 L c a Lca Lca

所以转化成树上给若干个关键点(对应的是字符串的前缀),对于每个 x x x,求 { y ∣ m x x − m x y ≤ m x l c a ( x , y ) } \{y|mx_x-mx_y\le mx_{lca(x,y)}\} {ymxxmxymxlca(x,y)}

考虑采用树上启发式合并,对于每个节点建立一颗动态开点线段树,我们让父亲继承重儿子的线段树,把其他子树中的线段树的节点暴力插入合并。

插入某个节点的时候,考虑线段树内的节点对其的贡献,和它对线段树内节点的贡献。前者用一个区间询问即可,否则在线段树上打标记,但这个线段树要被拆开的时候再把暴力标记推下去贡献到答案上即可。

每个节点插入的之后其所在线段树大小翻倍,所以之多插入 l o g log log次,总复杂度 O ( n l o g 2 ) O(nlog^2) O(nlog2)

代码

#include
const int N = 6e4 + 10, T = 1e6 + 10;
int ri() {
	char c = getchar(); int x = 0, f = 1; for(;c < '0' || c > '9'; c = getchar()) if(c == '-') f = -1;
	for(;c >= '0' && c <= '9'; c = getchar()) x = (x << 1) + (x << 3) - '0' + c; return x * f;
}
int nx[N], pr[N], fa[N], ch[N][26], st[N], rt[N], mx[N], ans1[N], ans2[N];
int ls[T], rs[T], cnt[T], tag[T], tp, *tot, top, n, last, sz;
bool val[N]; char s[N];
void Clear() {
	top = last = 1; pr[1] = 0;
	memset(ch[1], 0, sizeof(ch[1]));
	sz = 0;
	for(int i = 1;i <= n; ++i)
		tot[i] = 0;
}
void Push(int p) {
	if(tag[p]) {
		if(ls[p])
			tag[ls[p]] += tag[p];
		if(rs[p])
			tag[rs[p]] += tag[p];
		tag[p] = 0;
	}
}
void Get(int p, int L, int R) {
	if(L == R) {
		tot[L] += tag[p];
		st[++tp] = L;
		return ;
	}
	int m = L + R >> 1; Push(p);
	if(ls[p])
		Get(ls[p], L, m);
	if(rs[p])
		Get(rs[p], m + 1, R);
}
void Ins(int &p, int L, int R, int x) {
	if(!p) {p = ++sz; ls[p] = rs[p] = tag[p] = cnt[p] = 0;}
	++cnt[p]; if(L == R) return ;
	int m = L + R >> 1; Push(p);
	if(x <= m) Ins(ls[p], L, m, x);
	else Ins(rs[p], m + 1, R, x);
}
void Modify(int p, int L, int R, int st, int ed) {
	if(L == st && ed == R)
		return ++tag[p], void();
	int m = L + R >> 1; Push(p);
	if(st <= m && ls[p])
		Modify(ls[p], L, m, st, std::min(ed, m));
	if(ed > m && rs[p])
		Modify(rs[p], m + 1, R, std::max(st, m + 1), ed);
}
int Query(int p, int L, int R, int st, int ed) {
	if(L == st && ed == R)
		return cnt[p];
	int m = L + R >> 1, ans = 0;
	if(st <= m && ls[p])
		ans += Query(ls[p], L, m, st, std::min(ed, m));
	if(ed > m && rs[p])
		ans += Query(rs[p], m + 1, R, std::max(st, m + 1), ed);
	return ans;
}
void Extend(int c) {
	int p = last, np = last = ++top; 
	memset(ch[np], 0, sizeof(ch[np])); pr[np] = 0;
	mx[np] = mx[p] + 1; rt[np] = 0; val[np] = true;
	for(;p && !ch[p][c]; p = fa[p])
		ch[p][c] = np;
	if(!p) fa[np] = 1;
	else {
		int q = ch[p][c];
		if(mx[q] == mx[p] + 1) 
			fa[np] = q;
		else {
			int nq = ++top; mx[nq] = mx[p] + 1;
			memcpy(ch[nq], ch[q], sizeof(ch[nq]));
			rt[nq]  = 0; val[nq] = false; pr[nq] = 0;
			fa[nq] = fa[q];
			fa[q] = fa[np] = nq;
			for(;ch[p][c] == q; p = fa[p])
				ch[p][c] = nq;
		}
	}
}
void Merge(int rt, int x, int c) {
	if(x - c)
		tot[x] += Query(rt, 1, n, x - c, x);
	Modify(rt, 1, n, x, std::min(n, x + c));
}
void Dfs(int u) {
	if(!pr[u])
		return Ins(rt[u], 1, n, mx[u]);
	int ds = 0;
	for(int i = pr[u]; i; i = nx[i]) {
		Dfs(i);
		if(cnt[rt[i]] > cnt[rt[ds]])
			ds = i;
	}
	rt[u] = rt[ds];
	for(int i = pr[u]; i; i = nx[i]) 
		if(i != ds){
			tp = 0; Get(rt[i], 1, n);
			for(int x = 1; x <= tp; ++x)
				Merge(rt[u], st[x], mx[u]);
			for(int x = 1;	x <= tp; ++x)
				Ins(rt[u], 1, n, st[x]);
		}
	if(val[u]) {
		Merge(rt[u], mx[u], mx[u]); 
		Ins(rt[u], 1, n, mx[u]);
	}
}
void Work() {
	Clear();
	for(int i = 1; i <= n; ++i)
		Extend(s[i] - 'a');
	for(int i = 2;i <= top; ++i)
		nx[i] = pr[fa[i]], pr[fa[i]] = i;
	Dfs(1); tp = 0; Get(rt[1], 1, n);
}
int main() {
	for(int T = ri(); T--; ) {
		scanf("%s", s + 1); n = strlen(s + 1);
		tot = ans1; Work();
		std::reverse(s + 1, s + n + 1);
		tot = ans2; Work();
		long long ans = 0;
		for(int i = 1;i <= n; ++i)
			ans += 1LL * ans1[i] * ans2[n - i];
		printf("%lld\n", ans);
	}
	return 0;
}

你可能感兴趣的:(数据结构-线段树&&树状数组)