[LOJ2339][虚树][边分治][树形DP]WC2018:通道

LOJ2339

44pts暴力就不用讲了

两棵树的做法似乎是个套路?先拆距离变成 d e p 1 [ x ] + d e p 1 [ y ] − 2 ∗ d e p 1 [ l c a 1 ( x , y ) ] + d i s 2 ( x , y ) dep1[x]+dep1[y]-2*dep1[lca1(x,y)]+dis2(x,y) dep1[x]+dep1[y]2dep1[lca1(x,y)]+dis2(x,y),然后就可以在第一棵树上从下到上枚举lca,消去lca的影响,然后剩下的部分就可以在第二棵树上拆点,拆出来两点的连边长度为对应的dep1,然后第二棵树上的对应点集直径就是答案,这里的对应点集指的是第一棵树中某个子树的所有点在第二棵树上对应编号的点组成的集合
但由于我们枚举了lca,所以对应的答案必须由lca的两个不同子树的答案拼起来,当然直径的合并是很简单的,这里就不讲了
要注意的是对答案有贡献的只有四种情况,而转移有六种情况

接下来考虑三棵树:能不能想个办法把它变成两棵树?当然可以——树分治
我们可以对第三棵树统计经过当前分治中心的路径的长度,它被分治中心分成了两部分,这两部分也和lca无关,我们可以直接丢到第二棵树上一并统计。
但是这两条路径需要来自不同子树,那么如果点分治就很麻烦,需要统计很多情况,那我们就用边分治处理,这样只有两个部分,我们分别统计dp值即可

最好是用RMQ—LCA

那这题就愉快的解决啦

Code(码到断手 ):

#include
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define ll long long
using namespace std;
inline ll read(){
	ll res=0,f=1;char ch=getchar();
	while(!isdigit(ch)) {if(ch=='-') f=-f;ch=getchar();}
	while(isdigit(ch)) {res=(res<<1)+(res<<3)+(ch^48);ch=getchar();}
	return res*f;
}
const int N=1e5+5;
const int INF=1e9;
struct G{
	int vis[N<<1],nxt[N<<1],head[N],tot;
	ll c[N<<1];
	inline void add(int x,int y,ll z=0){vis[++tot]=y;nxt[tot]=head[x];head[x]=tot;c[tot]=z;}
	inline void dbadd(int x,int y,ll z=0){add(x,y,z);add(y,x,z);}
	inline void clear(int l,int r){
		for(int i=l;i<=r;i++) head[i]=0;
		tot=l;
	}
}tr[3],itr;
struct G1{
	int vis[N<<3],nxt[N<<3],head[N<<2],tot;
	ll c[N<<3];
	inline void add(int x,int y,ll z=0){vis[++tot]=y;nxt[tot]=head[x];head[x]=tot;c[tot]=z;}
	inline void dbadd(int x,int y,ll z=0){add(x,y,z);add(y,x,z);}
	inline void clear(int l,int r){
		for(int i=l;i<=r;i++) head[i]=0;
		tot=l;
	}
}tr1;
int n;
ll mid;
int col[N],pt[N];
namespace Dp{
	ll dep[4][N];
	void dfsdep(int v,int fa){
		dep[0][v]+=dep[3][v];
		for(int i=tr[1].head[v];i;i=tr[1].nxt[i]){
			int y=tr[1].vis[i];
			if(y==fa) continue;
			dep[3][y]=dep[3][v]+tr[1].c[i];
			dfsdep(y,v);
		}
	}
	void dfsdep2(int v,int fa){
		dep[0][v]+=dep[2][v];
		for(int i=itr.head[v];i;i=itr.nxt[i]){
			int y=itr.vis[i];
			if(y==fa) continue;
			dep[2][y]=dep[2][v]+itr.c[i];
			dfsdep2(y,v);
		}
	}
	int lg[N<<1];ll stt[N<<1][19],d[N];
	int id[N],sign=0,Dep[N]; 
  	void dfs(int x,int fa){
		stt[++sign][0]=d[x],id[x]=sign;
		for(int i=tr[1].head[x];i;i=tr[1].nxt[i]){
			int y=tr[1].vis[i];
  	 		if(y==fa) continue; 
			Dep[y]=Dep[x]+1;
			d[y]=d[x]+tr[1].c[i];
			dfs(y,x);
			stt[++sign][0]=d[x];
		}
	}
	inline void st(){
		for(int i=2;i<=sign;i++) lg[i]=lg[i>>1]+1;
		for(int j=1;j<=18;j++)
			for(int i=1;i+(1<<j)-1<=sign;++i)
				stt[i][j]=min(stt[i][j-1],stt[i+(1<<(j-1))][j-1]);
	}
	inline ll Lca(int x,int y){
		x=id[x],y=id[y];if(x>y) swap(x,y);
		int len=lg[y-x+1];
		return min(stt[x][len],stt[y-(1<<len)+1][len]);
	}
	inline ll lca(int x,int y){return dep[0][x]+dep[0][y]-2*Lca(x,y);}
	struct info{
		int u,v;
		ll L;
		info(){u=v=L=0;}
		info(int _u,int _v):u(_u),v(_v),L(lca(u,v)){}
		inline bool operator < (const info a)const{return L<a.L;}
		friend inline info operator + (info a,info b){
			if(a.u==0) return b;if(b.u==0) return a;
			info res=max(a,b);
			res=max(res,max(info(a.u,b.v),info(a.v,b.u))),res=max(res,max(info(a.u,b.u),info(a.v,b.v)));
			return res;
		}
	}dp[100010][2];
	inline ll calc(const info &a,const info &b){
		if(a.u==0 || b.u==0) return 0;
		return max(max(lca(a.u,b.u),lca(a.v,b.u)),max(lca(a.u,b.v),lca(a.v,b.v)));
	}
	ll ans;
	void dfsdp(int v,int fa){
		for(int i=tr[0].head[v];i;i=tr[0].nxt[i]){
			int y=tr[0].vis[i];
			if(y==fa) continue;
			dfsdp(y,v);
			ans=max(ans,max(calc(dp[v][0],dp[y][1]),calc(dp[v][1],dp[y][0]))+mid-(dep[2][v]<<1));
			dp[v][0]=dp[v][0]+dp[y][0];
			dp[v][1]=dp[v][1]+dp[y][1];
		}
		pt[v]=0;
		tr[0].head[v]=0; 
	}
	inline void work(){dfsdep(1,0);dfs(1,0);st();}
	inline void work1(){dfsdp(1,0);}
}
using Dp::info;
using Dp::dp;
namespace itree{
	int lg[N<<1],stt[N<<1][19],id[N],sign=0,sum=0,dfn[N],dep[N];ll d[N];
  	void dfs(int x,int fa){
		stt[++sign][0]=x,id[x]=sign,dfn[x]=++sum;
		for(int i=itr.head[x];i;i=itr.nxt[i]){
			int y=itr.vis[i];
  	 		if(y==fa) continue; 
			dep[y]=dep[x]+1;
			d[y]=d[x]+itr.c[i];
			dfs(y,x);
			stt[++sign][0]=x;
		}
	}
	inline int mn(int x,int y){return dep[x]<dep[y]?x:y;}
	inline void st(){
		for(int i=2;i<=sign;i++) lg[i]=lg[i>>1]+1;
		for(int j=1;j<=18;j++)
			for(int i=1;i+(1<<j)-1<=sign;++i)
				stt[i][j]=mn(stt[i][j-1],stt[i+(1<<(j-1))][j-1]);
	}
	inline int lca(int x,int y){
		x=id[x],y=id[y];if(x>y) swap(x,y);
		int len=lg[y-x+1];
		return mn(stt[x][len],stt[y-(1<<len)+1][len]);
	}
	int sta[N],tp=0;
	struct Q{
		int x;ll dep;
		Q(){}
		Q(int _x,ll _dep):x(_x),dep(_dep){}
	}a[N<<2];
	inline bool cmp(Q a,Q b){return dfn[a.x]<dfn[b.x];}
	int cnt;
	inline void ins(int x){
		int fa=lca(x,sta[tp]);
		if(!pt[fa]){pt[fa]=1;dp[fa][0]=dp[fa][1]=info(0,0);}
		while(tp>1 && dep[sta[tp-1]]>=dep[fa]){tr[0].add(sta[tp-1],sta[tp]),tp--;}
		if(fa!=sta[tp]) tr[0].add(fa,sta[tp]),sta[tp]=fa;
		sta[++tp]=x;
	}
	inline void build(ll val){
		mid=val;
		for(int i=1;i<=cnt;i++){
			pt[a[i].x]=1;
			dp[a[i].x][col[a[i].x]-1]=info(a[i].x,a[i].x);
			dp[a[i].x][(col[a[i].x]-1)^1]=info(0,0);
			col[a[i].x]=0;
		}
		sort(a+1,a+cnt+1,cmp);
		tr[0].tot=0;
		tp=0;if(a[1].x!=1) sta[++tp]=1;
		for(int i=1;i<=cnt;i++) ins(a[i].x);
		while(tp>1) tr[0].add(sta[tp-1],sta[tp]),--tp;
		Dp::work1();
	}
}
using itree::a;
namespace bfz{
	int rt,siz[N<<2],pt[N<<2],mx,idx;
	vector<pair<int,ll> >vec[N<<2];
	void dfs(int v,int fa){
		for(int i=tr1.head[v];i;i=tr1.nxt[i]){
			int y=tr1.vis[i];
			if(y==fa) continue;
			vec[v].pb(mp(y,tr1.c[i]));
			dfs(y,v);
		}
	}
	inline void rebuild(){
		tr1.tot=1;
		memset(tr1.head,0,sizeof(tr1.head));
		for(int i=1;i<=idx;i++){
			int sz=vec[i].size();
			if(sz<=2) for(int j=0;j<sz;j++) tr1.dbadd(i,vec[i][j].fi,vec[i][j].se);
			else{
				int ls=++idx,rs=++idx;
				tr1.dbadd(i,ls,0);tr1.dbadd(i,rs,0);
				for(int j=0;j<sz;j++){
					if(j&1) vec[ls].pb(mp(vec[i][j].fi,vec[i][j].se));
					else vec[rs].pb(mp(vec[i][j].fi,vec[i][j].se));
				}
			}
		}
	}
	void getroot(int v,int fa,int sum){
		siz[v]=1;
		for(int i=tr1.head[v];i;i=tr1.nxt[i]){
			int y=tr1.vis[i];
			if(!pt[i>>1] && y!=fa){
				getroot(y,v,sum);
				siz[v]+=siz[y];
				int tmp=max(siz[y],sum-siz[y]);
				if(tmp<mx) mx=tmp,rt=i;
			}
		}
	}
	void buildi(int v,int fa,ll dep,int op){
		if(v<=n) col[v]=op,a[++itree::cnt]=itree::Q(v,dep);
		for(int i=tr1.head[v];i;i=tr1.nxt[i]){
			int y=tr1.vis[i];
			if(!pt[i>>1] && y!=fa) buildi(y,v,dep+tr1.c[i],op);
		}
	}
	void solve(int v,int sum){
		mx=INF;getroot(v,0,sum);
		if(mx==INF) return;
		int now=rt;pt[now>>1]=1;
		itree::cnt=0;
		buildi(tr1.vis[now],0,0ll,1);
		buildi(tr1.vis[now^1],0,0ll,2);
		for(int i=1;i<=itree::cnt;i++) Dp::dep[0][a[i].x]+=a[i].dep;
		itree::build(tr1.c[now]);
		for(int i=1;i<=itree::cnt;i++) Dp::dep[0][a[i].x]-=a[i].dep;
		int sz=siz[tr1.vis[now]];
		solve(tr1.vis[now],sz);
		solve(tr1.vis[now^1],sum-sz);
	}
}
int main(){ int size=400<<20;//40M
    //__asm__ ("movl  %0, %%esp\n"::"r"((char*)malloc(size)+size));//调试用这个 
    __asm__ ("movq %0,%%rsp\n"::"r"((char*)malloc(size)+size));//提交用这个 

    //main函数代码 
	n=read();bfz::idx=n;
	for(int x,y,i=1;i<n;i++){
		x=read(),y=read();ll z=read();
		tr1.dbadd(x,y,z);
	}
	for(int x,y,i=1;i<n;i++){
		x=read(),y=read();ll z=read();
		itr.dbadd(x,y,z);
	}
	for(int x,y,i=1;i<n;i++){
		x=read(),y=read();ll z=read();
		tr[1].dbadd(x,y,z);
	}
	itree::dfs(1,0);itree::st();
	Dp::dfsdep2(1,0);Dp::work();
	bfz::dfs(1,0);bfz::rebuild();bfz::solve(1,bfz::idx);
	cout<<Dp::ans;
	exit(0);
	return 0;
}

你可能感兴趣的:(边分治,树形DP,虚树)