【BZOJ】【P3697】【采药人的路径】【题解】【点分治】

传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=3697

阳视作1 阴视作-1 统计路径为0且祖先有和自己权值相同的个数

Code:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<vector>
#include<cctype>
#include<map>
#include<set>
using namespace std;
typedef long long LL;
const int maxn=1e5+5;
LL ans=0;
int n;
int getint(){
	int res=0;char c=getchar();
	while(!isdigit(c))c=getchar();
	while(isdigit(c))res=res*10+c-'0',c=getchar();
	return res;
}
struct edge{int u,v,w;};
vector<edge>G[maxn];
int siz[maxn],f[maxn],dep[maxn],cant[maxn],root,All,d[maxn];
map<int,int>M,mp,MA,data,data2;
void makert(int u,int fa){
	siz[u]=1;f[u]=0;
	for(int i=0;i<G[u].size();i++){
		edge e=G[u][i];
		if(e.v!=fa&&!cant[e.v]){
			dep[e.v]=dep[u]+1;
			makert(e.v,u);
			siz[u]+=siz[e.v];
			f[u]=max(f[u],siz[e.v]);
		}
	}f[u]=max(f[u],All-f[u]);
	if(f[root]>f[u])root=u;
}
void dfs(int u,int fa){
	data[d[u]]++;
	for(int i=0;i<G[u].size();i++){
		edge e=G[u][i];
		if(e.v==fa||cant[e.v])continue;
		d[e.v]=d[u]+e.w;
		if(M.count(d[e.v])){
			data2[d[e.v]]++;
//			cerr<<d[e.v]<<endl;
		}
		M[d[e.v]]++;
		dfs(e.v,u);
		if(!--M[d[e.v]])M.erase(d[e.v]);
	}
}
typedef map<int,int>::iterator iter;
void deb(map<int,int>M){
	puts("");
	for(iter it=M.begin();it!=M.end();it++)if(it->second)
	cout<<it->first<<" "<<it->second<<endl;
	
}
void calc(int u){
	LL res=0;MA.clear();mp.clear();d[u]=0;
	for(int i=0;i<G[u].size();i++){
		edge e=G[u][i];
		if(cant[e.v])continue;
		d[e.v]=e.w;M[e.w]++;data.clear();data2.clear();
		dfs(e.v,u);M.erase(e.w);
//		deb(MA);
//		deb(mp);
//		deb(data);
//		deb(data2);
		for(iter it=data.begin();it!=data.end();it++){
			LL num=it->first,cant=it->second-data2[it->first],can=data2[it->first];
			if(!num){
				res+=it->second*MA[0];
			}else{
				res+=can*MA[-num];
				res+=cant*mp[-num];				
			}
		}
		for(iter it=data.begin();it!=data.end();it++)MA[it->first]+=it->second;
		for(iter it=data2.begin();it!=data2.end();it++)mp[it->first]+=it->second;
	}ans+=res;ans+=mp[0];
}
void solve(int u){
	calc(u);cant[u]=1;
	for(int i=0;i<G[u].size();i++){
		edge e=G[u][i];
		if(cant[e.v])continue;
		All=siz[e.v];
		f[root=0]=n+1;
		makert(e.v,0);
		solve(root);
	}
}
int main(){
	n=getint();All=n;
	for(int i=1;i<n;i++){
		int u=getint(),v=getint(),w=getint();
		G[u].push_back((edge){u,v,w?w:-1});
		G[v].push_back((edge){v,u,w?w:-1});
	}f[root=0]=n+1;
	makert(1,1);
	solve(root);
	cout<<ans<<endl;
	return 0;
}


你可能感兴趣的:(bzoj)