感觉这种问题没有链分治解决不了的 , 如果有 , 那么再加一个线段树
每个重链结点记录下链中与他直接联系的黑点和白点的数目 , 这是很好维护的。 关键是询问。
线段树中的结点要记录这一段是否是同一种颜色 , 否则无法查询。 但这并不是最优的方式 , 最好是多维护两个量 , 分别是仅在重链上相同颜色从左端延续的长度和从右端延续的长度。 这样 findPath() 里就不用花 lg(n) 的时间去 judge() 了
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <string>
#include <vector>
#include <deque>
#include <stack>
#include <queue>
#include <set>
#include <map>
#include <algorithm>
using namespace std;
const int maxn = 1e5+1e2;
struct node{ int l , r , s; node(int l = 0 , int r = 0 , int s=1):l(l),r(r),s(s){} };
int n , dfsCnt , cnt;
int fa[maxn] , id[maxn] , reid[maxn] , bl[maxn] , Size[maxn] , s[maxn] , c[maxn];
vector<int> g[maxn];
void dfs(int u)
{
Size[u] = 1;
for(int i=0;i<g[u].size();i++)
{
int t = g[u][i];
if(t == fa[u]) continue;
fa[t] = u;
dfs(t);
Size[u] += Size[t];
}
}
void dfs(int u , int num)
{
reid[id[u] = ++dfsCnt] = u;
bl[u] = num;
s[num]++;
int mx = 0 , w;
for(int i=0;i<g[u].size();i++)
{
int t = g[u][i];
if(t == fa[u]) continue;
if(mx < Size[t]) mx = Size[t] , w = t;
}
if(mx) dfs(w, num);
for(int i=0;i<g[u].size();i++)
{
int t = g[u][i];
if(t == fa[u] || t == w) continue;
dfs(t, t);
}
}
node seg[maxn*4];
int ls[maxn*4] , rs[maxn*4] , root[maxn] , bw[maxn][2];
void maintain(int o , int x)
{
seg[o].l = seg[o].r = bw[x][c[x]] + 1;
}
node merge(node& a , node& b , bool same)
{
node res = node(a.l , b.r);
if(a.s && same) res.l = a.l + b.l;
if(b.s && same) res.r = b.r + a.r;
res.s = same && a.s && b.s;
return res;
}
void build(int o , int l , int r)
{
if(l==r)
{
int x = reid[l];
for(int i=0;i<g[x].size();i++)
{
int t = g[x][i];
if(t == fa[x] || bl[t] == bl[x]) continue;
build(root[t] = ++cnt, id[t], id[t]+s[t]-1);
bw[x][c[t]] += seg[root[t]].l;
}
maintain(o, x);
}
else
{
int mid = (l+r)/2;
build(ls[o] = ++cnt, l, mid);
build(rs[o] = ++cnt, mid+1, r);
seg[o] = merge(seg[ls[o]], seg[rs[o]], c[reid[mid]] == c[reid[mid+1]]);
}
}
bool judge(int o , int l , int r , int L , int R)
{
if(l==r) return true;
else
{
int mid = (l+r)/2 , res = 1;
if(L <= mid) res &= judge(ls[o], l, mid, L, R);
if(R > mid) res &= judge(rs[o], mid+1, r, L, R);
if(L <= mid && R > mid) res &= c[reid[mid]] == c[reid[mid+1]];
return res;
}
}
deque<int> pat;
void findPath(int x , bool noted = false)
{
pat.clear();
while(x)
{
int f = bl[x];
if(noted)
{
if(!judge(root[f], id[f], id[f]+s[f]-1, id[f], id[x]) || c[fa[f]] != c[f] || !fa[f])
{
pat.push_front(x);
pat.push_front(f);
break;
}
}
else
{
pat.push_front(x);
pat.push_front(f);
}
x = fa[f];
}
}
void modify(int o , int l , int r , int i)
{
if(l==r)
{
int x = reid[l];
if(i+2 < pat.size())
{
int t = pat[i+1];
bw[x][c[t]] -= seg[root[t]].l;
modify(root[t], id[t], id[t]+s[t]-1, i+2);
bw[x][c[t]] += seg[root[t]].l;
}
else c[x] = 1-c[x];
maintain(o, x);
}
else
{
int mid = (l+r)/2;
if(id[pat[i]] <= mid) modify(ls[o], l, mid, i);
else modify(rs[o], mid+1, r, i);
seg[o] = merge(seg[ls[o]], seg[rs[o]], c[reid[mid]] == c[reid[mid+1]]);
}
}
int query(int o , int l , int r , int i , bool& ln , bool& rn)
{
if(l==r) return seg[o].l;
else
{
int mid = (l+r)/2 , res = 0;
if(id[pat[i]] <= mid)
{
res += query(ls[o], l, mid, i , ln , rn);
if(rn && c[reid[mid+1]] == c[reid[mid]]) res += seg[rs[o]].l;
rn = rn && seg[rs[o]].s && c[reid[mid+1]] == c[reid[mid]];
}
else
{
res += query(rs[o], mid+1, r, i , ln , rn);
if(ln && c[reid[mid+1]] == c[reid[mid]]) res += seg[ls[o]].r;
ln = ln && seg[ls[o]].s && c[reid[mid+1]] == c[reid[mid]];
}
return res;
}
}
int main(int argc, char *argv[]) {
cin>>n;
for(int i=1;i<n;i++)
{
int a ,b;
scanf("%d%d" , &a , &b);
g[a].push_back(b);
g[b].push_back(a);
}
dfs(1);
dfs(1, 1);
build(root[1] = ++cnt, 1, s[1]);
int x , y , Q;
bool T1 , T2;
cin>>Q;
while(Q--)
{
scanf("%d%d" , &x , &y);
if(x)
{
findPath(y);
modify(1, 1, s[1], 1);
}
else
{
findPath(y , true);
int now = pat[0];
printf("%d\n" , query(root[now],id[now] ,id[now]+s[now]-1 , 1 , T1=1 , T2=1));
}
}
return 0;
}