【BJOI2017】树的难题【点分治】【线段树】

传送门
传送门

题意:给一棵树,树上有颜色,每种颜色有权值,定义一条路径的权值为所有颜色相同段的权值之和,求长度在 [ L , R ] [L,R] [L,R]中的路径的最大权值。

数据范围:暴力过不了

显然是个点分治

对于分治中心考虑过中心的路径贡献的答案

(以下的“子树”指根直接与分治中心相连的子树)

把分治中心作为根,定义一棵子树的颜色为与根(即分治中心)相连的边的颜色

显然可以算出两边的贡献,如果两边颜色一样,再减掉这个权值

有个单调队列的神仙做法,看不懂

所以自己 y y yy yy了线段树做法(题解也有)

把所有子树按颜色排序,把相同颜色放一起

维护两棵线段树,表示长度为 x x x的路径的最大权值,一棵维护相同颜色,一棵维护不同颜色

遍历所有子树的所有节点,设当前节点的深度为 d i s dis dis,区间查询两棵线段树中 [ L − d i s , R − d i s ] [L-dis,R-dis] [Ldis,Rdis]的答案。

遍历完一棵子树后,将它们加入同色线段树。

当进入不同颜色时,把前一个颜色的点插入异色线段树,并给同色线段树打清除标记

复杂度大概 O ( N l o g N 2 ) O(Nlog_N^2) O(NlogN2)

#include 
#include 
#include 
#include 
#include 
#define MAXN 200005
#define MAXM 400005
#define re register
#define INF 0x3f3f3f3f
using namespace std;
class fast_input {
private:
    static const int SIZE = 1 << 15 | 1;
    char buf[SIZE], *front, *back;

    void Next(char &c) {
        if(front == back) back = (front = buf) + fread(buf, 1, SIZE, stdin);
        c = front == back ? (char)EOF : *front++;
    }

public :
    template<class T>void operator () (T &x) {
        char c, f = 1;
        for(Next(c); !isdigit(c); Next(c)) if(c == '-') f = -1;
        for(x = 0; isdigit(c); Next(c)) x = x * 10 + c - '0';
        x *= f;
    }
    void operator () (char &c, char l = 'a', char r = 'z') {
        for(Next(c); c > r || c < l; Next(c)) ;
    }
}input;

struct edge{int u,v,w;}e[MAXM];
int head[MAXN],nxt[MAXM],cnt;
void addnode(int u,int v,int w)
{
	e[++cnt]=(edge){u,v,w};
	nxt[cnt]=head[u];
	head[u]=cnt;
}
int cost[MAXN],n,m,L,R;
bool cut[MAXN];
int siz[MAXN],sum[MAXN],maxp[MAXN],root;
void findroot(int u,int f,int sum)
{
	siz[u]=1,maxp[u]=0;
	for (re int i=head[u];i;i=nxt[i])
		if (e[i].v!=f&&!cut[e[i].v])
		{
			findroot(e[i].v,u,sum);
			siz[u]+=siz[e[i].v];
			maxp[u]=max(maxp[u],siz[e[i].v]);
		}
	if (sum-siz[u]>maxp[u]) maxp[u]=sum-siz[u];
	if (maxp[u]<maxp[root]) root=u;
}
struct SGT
{
	#define lc p<<1
	#define rc p<<1|1
	int mx[MAXN<<2],cl[MAXN<<2];
	inline void pushcl(int p){mx[p]=-INF;cl[p]=1;}
	inline void pushdown(int p){if (cl[p]) pushcl(lc),pushcl(rc),cl[p]=0;}
	void build(int p,int l,int r)
	{
		mx[p]=-INF;
		if (l==r) return;
		int mid=(l+r)>>1;
		build(lc,l,mid),build(rc,mid+1,r);
	}
	void modify(int p,int l,int r,int k,int v)
	{
		if (l==r) return (void)(mx[p]=max(mx[p],v));
		pushdown(p);
		int mid=(l+r)>>1;
		if (k<=mid) modify(lc,l,mid,k,v);
		else modify(rc,mid+1,r,k,v);
		mx[p]=max(mx[lc],mx[rc]);
	}
	int query(int p,int l,int r,int ql,int qr)
	{
		pushdown(p);
		if (ql<=l&&r<=qr) return mx[p];
		if (r<ql||qr<l) return -INF;
		int mid=(l+r)>>1;
		return max(query(lc,l,mid,ql,qr),query(rc,mid+1,r,ql,qr));
	}
	inline void clear(){pushcl(1);}
}sam,dif;
int ans=-INF,maxlen;
int nod[MAXN],col[MAXN];
inline bool cmp(const int& a,const int& b){return col[a]<col[b];}
int tot;
int len[MAXN],val[MAXN],num,_len[MAXN],_val[MAXN],_num;
int dfs(int u,int f,int l,int v,int las_col)
{
	int s=1;
	if (l>=L&&l<=R) ans=max(ans,v);
	maxlen=max(maxlen,l);
	len[++num]=l,val[num]=v;
	_len[++_num]=l,_val[_num]=v;
	for (int i=head[u];i;i=nxt[i])
		if (e[i].v!=f&&!cut[e[i].v])
			s+=dfs(e[i].v,u,l+1,e[i].w==las_col? v:v+cost[e[i].w],e[i].w);
	return s;
}
void calc()
{
	maxlen=0;
	for (re int i=head[root];i;i=nxt[i])
		if (!cut[e[i].v])
			nod[++tot]=e[i].v,col[e[i].v]=e[i].w;
	sort(nod+1,nod+tot+1,cmp);
	for (re int i=1;i<=tot;i++)
	{
		if (col[nod[i]]!=col[nod[i-1]])
		{
			for (int u=1;u<=_num;u++) dif.modify(1,1,n,_len[u],_val[u]);
			sam.clear(),_num=0;
		}
		sum[nod[i]]=dfs(nod[i],0,1,cost[col[nod[i]]],col[nod[i]]);
		for (re int u=1;u<=num;u++)
		{
			ans=max(ans,val[u]+sam.query(1,1,n,L-len[u],R-len[u])-cost[col[nod[i]]]);
			ans=max(ans,val[u]+dif.query(1,1,n,L-len[u],R-len[u]));
		}
		for (re int u=1;u<=num;u++) sam.modify(1,1,n,len[u],val[u]);
		num=0;
	}
	sam.clear(),dif.clear();
	tot=num=_num=0;
}
void solve()
{
	cut[root]=true;
	calc();
	for (re int i=head[root];i;i=nxt[i])
		if (!cut[e[i].v])
		{
			maxp[root=0]=INF;
			findroot(e[i].v,0,sum[e[i].v]);
			sum[e[i].v]=0,solve();
		}
}
int main()
{
	input(n),input(m),input(L),input(R);
	for (re int i=1;i<=m;i++) input(cost[i]);
	for (re int i=1;i<n;i++)
	{
		int u,v,w;
		input(u),input(v),input(w);
		addnode(u,v,w),addnode(v,u,w);
	}
	sam.build(1,1,n),dif.build(1,1,n);
	maxp[root=0]=INF;
	findroot(1,0,n);
	solve();
	printf("%d\n",ans);
	return 0;
}

你可能感兴趣的:(【BJOI2017】树的难题【点分治】【线段树】)