传送门
有一棵树, n n n 个节点,每条边有边权 c i c_i ci,每个点的颜色为 R,G,B 三种颜色之一(R 为红色,G 为绿色,B 为蓝色)。
你要统计有序对 ( U , V ) (U,V) (U,V) 的数量,其中 U , V U,V U,V 是满足以下条件的两个点集:
答案对 1 0 9 + 7 10^9+7 109+7 取模。
数据范围: n ≤ 2000 n≤2000 n≤2000, w ≤ 1 0 7 w≤10^7 w≤107, c i ≤ 1 0 6 c_i≤10^6 ci≤106。
听说这是一道论文题。
我们先考虑只有一个绿点的时候怎么做。
考虑把绿点当作根,然后树形 d p dp dp。设 R ( x ) \texttt R(x) R(x) 表示强制选 x x x 时,以 x x x 为根的子树中只有红点和绿点的连通块个数。 B ( x ) \mathtt B(x) B(x) 的定义与之类似:
R ( x ) = ∏ v ∈ s o n ( x ) ( R ( v ) + 1 ) B ( x ) = ∏ v ∈ s o n ( x ) ( B ( v ) + 1 ) \begin{aligned} \mathtt R(x)&=\prod_{v\in son(x)}(\texttt R(v)+1)\\ \mathtt B(x)&=\prod_{v\in son(x)}(\texttt B(v)+1) \end{aligned} R(x)B(x)=v∈son(x)∏(R(v)+1)=v∈son(x)∏(B(v)+1)
意思就是, x x x 的儿子 v v v 都有 R ( v ) + 1 \texttt R(v)+1 R(v)+1 种选择( + 1 +1 +1 是因为可以不选),由于强制选 x x x,用乘法原理乘一下就是答案。
那么答案就是 R ( r ) × B ( r ) \texttt R(r)\times \texttt B(r) R(r)×B(r), r r r 就是那个作为根的绿点。
那么如果有几个绿点是连通的,答案会被算重,考虑如果减去重复的部分。
这时想到,对于连通的一些绿点,点数减去边数等于 1 1 1。因此我们可以先枚举点加上,再枚举边减去。
时间复杂度 O ( n 2 ) O(n^2) O(n2)。
#include
using namespace std;
const int N=4005,P=1e9+7;
int n,lim,t,ans,col[N];
int first[N],v[N],w[N],nxt[N];
struct edges{int u,v,w;}e[N];
int add(int x,int y) {return x+y>=P?x+y-P:x+y;}
int dec(int x,int y) {return x-y< 0?x-y+P:x-y;}
int mul(int x,int y) {return 1ll*x*y%P;}
void edge(int x,int y,int z){
nxt[++t]=first[x],first[x]=t,v[t]=y,w[t]=z;
}
int R[N],B[N];
void dfs(int x,int fa,int dis){
if(dis>lim) {R[x]=B[x]=0;return;}
R[x]=(col[x]!=3),B[x]=(col[x]!=1);
for(int i=first[x];i;i=nxt[i]){
int to=v[i];
if(to==fa) continue;
dfs(to,x,dis+w[i]);
R[x]=mul(R[x],R[to]+1);
B[x]=mul(B[x],B[to]+1);
}
}
void calc(int x){
dfs(x,0,0);
ans=add(ans,mul(R[x],B[x]));
}
void calc(int u,int v,int w){
dfs(u,v,w),dfs(v,u,w);
ans=dec(ans,mul(mul(R[u],B[u]),mul(R[v],B[v])));
}
char S[N];
int main(){
scanf("%d%d%s",&n,&lim,S+1);
for(int i=1;i<=n;++i){
if(S[i]=='R') col[i]=1;
if(S[i]=='G') col[i]=2;
if(S[i]=='B') col[i]=3;
}
for(int i=1,x,y,z;i<n;++i){
scanf("%d%d%d",&x,&y,&z);
edge(x,y,z),edge(y,x,z),e[i]=(edges){x,y,z};
}
for(int i=1;i<=n;++i) if(col[i]==2) calc(i);
for(int i=1;i<n;++i){
int u=e[i].u,v=e[i].v,w=e[i].w;
if(col[u]==2&&col[v]==2) calc(u,v,w);
}
printf("%d\n",ans);
return 0;
}