hdu 4897 Little Devil I 树剖(题如其名..)

题意:最开始树上的边全是0,在树上有三种操作,1 使某条路径上的边翻转(0->1, 1->0)。2 翻转某条路径上的与 a 和 b 之间的简单路径相邻的所有边。3查询某条路径上有多少个1。

详参考http://blog.csdn.net/u013368721/article/details/39338679

先是树剖无疑,然后用线段树维护的时候设置三个遍历:sum记录1的个数,flip记录该点翻转的情况,mark标记以该点为父亲节点时,与该点相连的轻边的翻转情况(不包括其父亲节点连向它的轻边)。也就是说,flip标记的是这个点本身,mark记录的是边的情况- -

注意一下几点:

1:第一步更新就是普通的区间更新没什么。由两条重链必定是由一条轻边连接的特点可知,当进行第二种更新时,两端都在重链上,那么直接给这一段的区间mark一下,但是在路径上的轻边势必会受到它的父亲节点mark标记的影响,那么我们将以该轻边的起点翻转一下,来抵消其父亲节点mark的影响。

2:注意第二步更新的时候,要特别处理链头和链尾,即 pos[ pre[x] ] 和 pos[ son[x] ],特别拉出来处理

3:query的时候,如果是当前两端是在重链上,那么直接res += 该端区间的1,注意到轻边是要受到父亲节点mark标记的影响的,所以在查询时要特殊处理

4:处理好细节。。。

5:处理好细节。。。

....

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

using namespace std;

#define lson l, mid, rt << 1
#define rson mid + 1, r, rt << 1 | 1
#define ls rt << 1
#define rs rt << 1 | 1
#define pi acos(-1.0)
#define eps 1e-8
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int N = 100010;

struct node{
	int nxt, v;
}e[N << 1];

int tot, cnt;
int pos[N];
int sz[N];
int dep[N];
int pre[N];
int son[N];
int head[N];
int top[N];
int n, q;

void init()
{
	cnt = tot = 0;
	sz[0] = dep[1] = 0;
	memset( head, -1, sizeof( head ) );
}

void add( int u, int v )
{
	e[cnt].v = v;
	e[cnt].nxt = head[u];
	head[u] = cnt++;

	e[cnt].v = u;
	e[cnt].nxt = head[v];
	head[v] = cnt++;
}

void dfs( int u )
{
	sz[u] = 1;
	son[u] = 0;
	for( int i = head[u]; ~i; i = e[i].nxt )
	{
		int v = e[i].v;
		if( v == pre[u] )
			continue;
		pre[v] = u;
		dep[v] = dep[u] + 1;
		dfs( v );
		sz[u] += sz[v];
		if( sz[v] > sz[ son[u] ] )
			son[u] = v;
	}
}

void rebuild( int u, int anc )
{
	pos[u] = ++tot;
	top[u] = anc;
	if( son[u] )
		rebuild( son[u], anc );
	for( int i = head[u]; ~i; i = e[i].nxt )
	{
		int v = e[i].v;
		if( v != pre[u] && v != son[u] )
			rebuild( v, v );
	}
}

struct seg{
	int l, r, x;
	bool flip, mark;
}tr[N << 2];

void pushup( int rt )
{
	tr[rt].x = tr[ls].x + tr[rs].x;
}

void down( int rt )
{
	if( tr[rt].flip )
	{
		tr[ls].flip ^= 1;
		tr[rs].flip ^= 1;
		int mid = ( tr[rt].r + tr[rt].l ) >> 1;
		tr[ls].x = mid - tr[rt].l + 1 - tr[ls].x;
		tr[rs].x = tr[rt].r - mid - tr[rs].x;
		tr[rt].flip = 0;
	}
	if( tr[rt].mark )
	{
		tr[ls].mark ^= 1;
		tr[rs].mark ^= 1;
		tr[rt].mark = 0;
	}
}

void build( int l, int r, int rt )
{
	tr[rt].l = l;
	tr[rt].r = r;
	tr[rt].x = tr[rt].flip = tr[rt].mark = 0;
	if( l == r )
		return;
	int mid = ( l + r ) >> 1;
	build( lson );
	build( rson );
	pushup( rt );
}

void Flip( int l, int r, int rt )
{
	if( l <= tr[rt].l && tr[rt].r <= r )
	{
		tr[rt].flip ^= 1;
		tr[rt].x = ( tr[rt].r - tr[rt].l + 1 ) - tr[rt].x;
		return;
	}
	down( rt );
	int mid = ( tr[rt].l + tr[rt].r ) >> 1;
	if( r <= mid )
		Flip( l, r, ls );
	else if( l > mid )
		Flip( l, r, rs );
	else
	{
		Flip( lson );
		Flip( rson );
	}
	pushup( rt );
}

int query( int l, int r, int rt )
{
	if( l <= tr[rt].l && tr[rt].r <= r )
		return tr[rt].x;
	down( rt );
	int mid = ( tr[rt].l + tr[rt].r ) >> 1;
	if( r <= mid )
		return query( l, r, ls );
	else if( l > mid )
		return query( l, r, rs );
	else
		return query( lson ) + query( rson );
}

void Mark( int l, int r, int rt )
{
	if( l <= tr[rt].l && tr[rt].r <= r )
	{
		tr[rt].mark ^= 1;
		return;
	}
	down( rt );
	int mid = ( tr[rt].l + tr[rt].r ) >> 1;
	if( r <= mid )
		Mark( l, r, ls );
	else if( l > mid )
		Mark( l, r, rs );
	else
	{
		Mark( lson );
		Mark( rson );
	}
	pushup( rt );
}

void update1( int x, int y )
{
	while( top[x] != top[y] )
	{
		int f1 = top[x], f2 = top[y];
		if( dep[f1] > dep[f2] )
		{
			Flip( pos[f1], pos[x], 1 );
			x = pre[f1];
		}
		else
		{
			Flip( pos[f2], pos[y], 1 );
			y = pre[f2];
		}
	}
	if( x == y )
		return;
	if( dep[x] > dep[y] )
		swap( x, y );
	Flip( pos[x]+1, pos[y], 1 );
}

void update2( int x, int y )
{
	while( top[x] != top[y] )
	{
		int f1 = top[x], f2 = top[y];
		if( dep[f1] > dep[f2] )
		{
			Mark( pos[f1], pos[x], 1 );
			if( son[x] )
				Flip( pos[x]+1, pos[x]+1, 1 );
			Flip( pos[f1], pos[f1], 1 );
			x = pre[f1];
		}
		else
		{
			Mark( pos[f2], pos[y], 1 );
			if( son[y] )
				Flip( pos[y]+1, pos[y]+1, 1 );
			Flip( pos[f2], pos[f2], 1 );
			y = pre[f2];
		}
	}
	if( dep[x] > dep[y] )
		swap( x, y );
	Mark( pos[x], pos[y], 1 );
	if( pre[x] )
		Flip( pos[x], pos[x], 1 );
	if( son[y] )
		Flip( pos[y]+1, pos[y]+1, 1 );
}

bool find_mark( int pos, int rt )
{
	if( tr[rt].l == tr[rt].r && tr[rt].l == pos )
		return tr[rt].mark;
	down( rt );
	int mid = ( tr[rt].l + tr[rt].r ) >> 1;
	if( pos <= mid )
		return find_mark( pos, ls );
	else
		return find_mark( pos, rs );
}

int query1( int x, int y )
{
	int res = 0;
	while( top[x] != top[y] )
	{
		int f1 = top[x], f2 = top[y];
		if( dep[f1] > dep[f2] )
		{
			if( f1 != x )
				res += query( pos[f1]+1, pos[x], 1 );
			res += ( query( pos[f1], pos[f1], 1) ^ find_mark( pos[pre[f1]], 1 ));
			x = pre[f1];
		}
		else
		{
			if( f2 != y )
				res += query( pos[f2]+1, pos[y], 1 );
			res += ( query( pos[f2], pos[f2], 1) ^ find_mark( pos[pre[f2]], 1 ));
			y = pre[f2];
		}
	}
	if( x == y )
		return res;
	if( dep[x] > dep[y] )
		swap( x, y );
	res += query( pos[x] + 1, pos[y], 1 );
	return res;
}

int main()
{
	int tt;
	scanf("%d", &tt);
	while( tt-- )
	{
		scanf("%d", &n);
		init();
		int u, v, op;
		for( int i = 1; i < n; ++i )
		{
			scanf("%d%d", &u, &v);
			add( u, v );
		}
		dfs( 1 );
		rebuild( 1, 1 );
		build( 1, n, 1 );
		scanf("%d", &q);
		while( q-- )
		{
			scanf("%d%d%d", &op, &u, &v);
			if( op == 1 )
				update1( u, v );
			else if( op == 2 )
				update2( u, v );
			else
			{
				int ans = query1( u, v );
				printf("%d\n", ans);
			}
		}
	}
	return 0;
}

各种细节要处理,,wa了两天心累

你可能感兴趣的:(树剖,HDU)