CCPC-Wannafly Winter Camp Day5 (Div1, onsite)(Nested Tree-虚树)

你有一棵n个点树T,然后你把它复制了m遍,然后在这m棵树之间又加了m−1条边,变成了一棵新的有nm个点的树T_2。求T_2中所有点对的距离和,由于答案很大,对10^9+7取模。
1 ≤ n ≤ 1 0 5 1\le n \le 10^5 1n105

dp 套 dp
预处理 d p s u m [ x ] dpsum[x] dpsum[x],表示所有1棵n个点的树中,所有点到点 x x x的距离和。

在m个点连成的新树中,每个拷贝和父拷贝连接的点为 t o p x top_x topx,子树大小 t o t x tot_x totx,子树中所有点到 t o p x top_x topx的距离和 d p s x dps_x dpsx.
对每个拷贝,统计所有点对的lca在该拷贝的路径对答案的贡献
子树的拷贝到该拷贝 a n s + = d e p s u m a ∗ t o t v + n ∗ d e p s v + n ∗ t o t v ans+=depsum_a*tot_v+n*deps_v+n*tot_v ans+=depsumatotv+ndepsv+ntotv
子树的拷贝到子树的拷贝 a n s + = d e p s v 2 ∗ t o t v + d e p s v ∗ t o t v 2 + t o t v ∗ t o t v 2 ∗ ( 2 + d i s ( a , a 2 ) ) ans+=deps_{v2}*tot_v+deps_{v}*tot_{v2}+tot_v*tot_{v2}*(2+dis(a,a_2)) ans+=depsv2totv+depsvtotv2+totvtotv2(2+dis(a,a2))
其中a和a2分别表示子树v和子树v2连出去的点
最后为了计算 d i s ( a , a 2 ) dis(a,a_2) dis(a,a2)需要用虚树进行处理。
处理出所有关键节点,并把所有连出子树v的位置a上以 t o t v tot_v totv为点权, s t o t x stot_x stotx为x的子树的点权和。
设all为虚树的点权和,每条虚树上的边(x,v)的贡献
a n s + = ( d e p v − d e p x ) ∗ s t o t v ∗ ( a l l − s t o t v ) ans+=(dep_v-dep_x)*stot_v*(all-stot_v) ans+=(depvdepx)stotv(allstotv),
最后更新deps,虚树上的边(x,v)对deps的贡献为
( d e p [ v ] − d e p [ x ] ) ∗ ( ( t o p x 在 v 的 子 树 ) ? ( a l l − s t o t x ) : s t o t x ) (dep[v]-dep[x])*((top_x在v的子树)?(all-stot_x):stot_x) (dep[v]dep[x])((topxv)?(allstotx):stotx)
每份拷贝v对 d e p s dep_s deps的贡献为
d e p s v + t o t v deps_v+tot_v depsv+totv

#include 
using namespace std;
#define For(i,n) for(int i=1;i<=n;i++)
#define Fork(i,k,n) for(int i=k;i<=n;i++)
#define ForkD(i,k,n) for(int i=n;i>=k;i--)
#define Rep(i,n) for(int i=0;i
#define ForD(i,n) for(int i=n;i>0;i--)
#define RepD(i,n) for(int i=n;i>=0;i--)
#define Forp(x) for(int p=pre[x];p;p=next[p])
#define Forpiter(x) for(int &p=iter[x];p;p=next[p])  
#define Lson (o<<1)
#define Rson ((o<<1)+1)
#define MEM(a) memset(a,0,sizeof(a));
#define MEMI(a) memset(a,0x3f,sizeof(a));
#define MEMi(a) memset(a,128,sizeof(a));
#define MEMx(a,b) memset(a,b,sizeof(a));
#define INF (0x3f3f3f3f)
#define F (1000000007)
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define vi vector 
#define pi pair
#define SI(a) ((a).size())
#define Pr(kcase,ans) printf("Case #%d: %lld\n",kcase,ans);
#define PRi(a,n) For(i,n-1) cout<
#define PRi2D(a,n,m) For(i,n) { \
						For(j,m-1) cout<
#pragma comment(linker, "/STACK:102400000,102400000")
#define ALL(x) (x).begin(),(x).end()
#define gmax(a,b) a=max(a,b);
#define gmin(a,b) a=min(a,b);
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
ll mul(ll a,ll b){return (a*b)%F;}
ll add(ll a,ll b){return (a+b)%F;}
ll sub(ll a,ll b){return ((a-b)%F+F)%F;}
void upd(ll &a,ll b){a=(a%F+b%F)%F;}
inline int read()
{
	int x=0,f=1; char ch=getchar();
	while(!isdigit(ch)) {if (ch=='-') f=-1; ch=getchar();}
	while(isdigit(ch)) { x=x*10+ch-'0'; ch=getchar();}
	return x*f;
} 
namespace VT{
	#define LL long long
	int cnt;
	int f[300001], g[300001];
	unsigned char lg2[1 << 20];
	int st[20][600001], len;
	LL h[300001];
	#ifdef EDGE
	vector<pair<int, int>> v[300001];
	void dfs(int i, int father)
	{
		int rank = st[0][len] = ++cnt;
		f[cnt] = i; g[i] = len++;
		for (unsigned int j = 0; j < v[i].size(); j++){
			int k = v[i][j].first;
			if (k != father){
				h[k] = h[i] + v[i][j].second;
				dfs(k, i);
				st[0][len++] = rank;
			}
		}
	}
	#else
	vector<int> v[300001];
	void dfs(int i, int father)
	{
		int rank = st[0][len] = ++cnt;
		f[cnt] = i; g[i] = len++;
		for (unsigned int j = 0; j < v[i].size(); j++){
			int k = v[i][j];
			if (k != father){
				h[k] = h[i] + 1;
				dfs(k, i);
				st[0][len++] = rank;
			}
		}
	}
	#endif
	void init(int root)
	{
		len = 0; cnt = 0;
	#ifdef EDGE
		v[0].push_back(make_pair(root, 0));
	#else
		v[0].push_back(root);
	#endif
		dfs(0, 0);
		for (unsigned char i = 0; (1 << i) <= len; i++)
			memset(&lg2[1 << i], i, 1 << i);
		for (int i = 1; i <= lg2[len]; i++){
			for (int k = len - (1 << i); k >= 0; k--)
				st[i][k] = min(st[i - 1][k], st[i - 1][k + (1 << (i - 1))]);
		}
	}
	inline int query(int i, int j){
		int pos1 = g[i], pos2 = g[j];
		if (pos1 > pos2)swap(pos1, pos2);
		int t = lg2[pos2 - pos1];
		return f[min(st[t][pos1], st[t][pos2 - (1 << t) + 1])];
	}
	int node[300001];
	int s[300001], top;
	vector<int> vt[300001];
	bool used[300001];//±£´æÐéÊ÷ÖÐÄÄЩµãÊÇÔ­ÊäÈëµã
	inline bool cmp(int i, int j){ return g[i] < g[j]; }
	void makeVirtualTree(int n)
	{
		sort(node, node + n, cmp);
		s[top = 0] = 0; node[n] = 0;
		for (int i = 0; i <= n; i++){
			int t = query(s[top], node[i]);
			if (t != s[top]){
				while (query(s[--top], t) != s[top]){
					vt[s[top]].push_back(s[top + 1]);
				}
				vt[t].push_back(s[top + 1]);
				if (s[top] != t)s[++top] = t;
			}
			s[++top] = node[i];
			used[node[i]] = 1;
		}
	}
	void clearVirtualTree(int i)
	{
		for (unsigned int j = 0; j < vt[i].size(); j++)
			clearVirtualTree(vt[i][j]);
		vt[i].clear(); used[i] = 0;
	}
	
}

int n,m;
#define MAXN (112345)
vi e[MAXN];
ll dfn[MAXN]={},totT=0,rr[MAXN]={};
ll dep[MAXN]={},depsum[MAXN]={},depup[MAXN]={},depdown[MAXN]={},sz[MAXN]={};
void dfs(int x,int fa) {
	dep[x]=dep[fa]+1;
	dfn[x]=++totT;
	sz[x]=1; 
	for(int v:e[x]) if(v^fa) {
		dfs(v,x); 
		sz[x]+=sz[v];
		depdown[x]+=depdown[v]+sz[v];
		depdown[x]%=F;
	}
	rr[x]=totT;
}
void dfs2(int x,int fa) {
	if(x^1) depup[x]=depup[fa]+depdown[fa]-depdown[x]-sz[x] +n-sz[x];
	depup[x]%=F;
	for(int v:e[x]) if(v^fa) {
		dfs2(v,x); 
	}
}
struct node{
	int b,u,v;
};
vector<node> e2[MAXN];
ll tot[MAXN]={},top[MAXN]={};
ll ans=0;
void dfs3(int x,int fa) {
	tot[x]=n;
	for(auto pa:e2[x]) {
		int v=pa.b,a=pa.u,b=pa.v;
		if(v^fa) {
			top[v]=b;
			dfs3(v,x);
			tot[x]=(tot[x]+tot[v])%F;
		}
	}
}
ll deps[MAXN]={},stot[MAXN]={};
ll all=0,_deps=0;
int _topx=0;
void dfs5(int x,int fa) {
	for(auto v:VT::vt[x]) if(v^fa){
		dfs5(v,x);
		upd(stot[x],stot[v]);
	}
}
void dfs6(int x,int fa) {
	for(auto v:VT::vt[x]) if(v^fa){
		ll w=dep[v]-dep[x];
		dfs6(v,x);
		upd(ans,mul( mul(stot[v], sub(all,stot[v])), w) );
	}
	for(auto v:VT::vt[x]) if(v^fa){
		ll w=dep[v]-dep[x],p=0;
		if(dfn[v]<=dfn[_topx]&& dfn[_topx] <=rr[v] ) p=sub(all,stot[v])*w%F;
		else p=stot[v]*w%F;
		upd(_deps,p);
	}
}
void dfs_clr(int x,int fa) {
	stot[x]=0;
	for(auto v:VT::vt[x]) if(v^fa){
		dfs_clr(v,x);
	}
}

void dfs4(int x,int fa) {
	ll totv2=0,depsv2=0;
	for(auto pa:e2[x]) {
		int v=pa.b,a=pa.u,b=pa.v;
		if(v^fa) {
			dfs4(v,x);
			
			upd(ans,depsum[a]*tot[v]%F+ deps[v]*n%F + n*tot[v]%F );
			upd(ans,depsv2*tot[v]%F+deps[v]*totv2%F + tot[v]*totv2%F*2%F);
			upd(totv2,tot[v]);
			upd(depsv2,deps[v]);
		}
	}
	{
		int sn=0;
		if(top[x]) VT::node[sn++]=top[x];
		for(auto pa:e2[x]) {
			int v=pa.b,a=pa.u,b=pa.v;
			if(v^fa) {
				VT::node[sn++]=a;
			}	
		}
		sort(VT::node,VT::node+sn);
		sn=unique(VT::node,VT::node+sn)-VT::node;
		Rep(i,sn) stot[VT::node[i]]=0;
		for(auto pa:e2[x]) {
			int v=pa.b,a=pa.u,b=pa.v;
			if(v^fa) {
				stot[a]+=tot[v];
			}	
		}
		VT::makeVirtualTree(sn);
		dfs5(0,-1);
		all=stot[0]; _topx=top[x]; _deps=0;
		dfs6(VT::vt[0][0],-1);
		upd(deps[x],_deps);
		upd(deps[x],depsum[top[x]]);
		for(auto pa:e2[x]) {
			int v=pa.b,a=pa.u,b=pa.v;
			if(v^fa) {
				upd(deps[x],deps[v]+tot[v]);
			}
		}
		dfs_clr(0,-1);
		VT::clearVirtualTree(0);
	}
	
}
ll pow2(ll a,int b,ll p)  //a^b mod p 
{  
    if (b==0) return 1%p;  
    if (b==1) return a%p;  
    ll c=pow2(a,b/2,p)%p;  
    c=c*c%p;  
    if (b&1) c=c*a%p;  
    return c%p;  
}  
ll inv(ll a,ll p) { //gcd(a,p)=1
	return pow2(a,p-2,p);
}
ll inv2=inv(2,F);

int main()
{
//	freopen("H.in","r",stdin);
//	freopen(".out","w",stdout);
	cin>>n>>m;
	For(i,n-1) {
		int u=read(),v=read();
		e[u].pb(v);e[v].pb(u);
		VT::v[u].pb(v);
		VT::v[v].pb(u);
	}	
	VT::init(1);
	dfs(1,0);
	dfs2(1,0);
	For(x,n) depsum[x]=(depdown[x]+depup[x])%F;
	For(i,m-1) {
		int a=read(),b=read(),u=read(),v=read();
		e2[a].pb(node{b,u,v}),	e2[b].pb(node{a,v,u});
	}
	dfs3(1,0);
	dfs4(1,0);
	ll ps=0;
	For(x,n) upd(ps,depsum[x]*m%F);
	ps=ps*inv2%F;
	upd(ans,ps);
	cout<<ans<<endl;
	return 0;
}

你可能感兴趣的:(虚树,dp套dp)