Tree Cutting (Easy Version)
You are given an undirected tree of n vertices.
Some vertices are colored blue, some are colored red and some are uncolored. It is guaranteed that the tree contains at least one red vertex and at least one blue vertex.
You choose an edge and remove it from the tree. The tree falls apart into two connected components. Let's call an edge nice if neither of the resulting components contains vertices of both red and blue colors.
How many nice edges are there in the given tree?
Input
The first line contains a single integer n (2 ≤ n ≤ 3·10^5), the number of vertices in the tree.
The second line contains n integers a_1, a_2, …, a_n (0 ≤ a_i ≤ 2), the colors of the vertices. a_i = 1 means that vertex i is colored red, a_i = 2 means that vertex i is colored blue, and a_i = 0 means that vertex i is uncolored.
The i-th of the next n−1 lines contains two integers v_i and u_i (1 ≤ v_i, u_i ≤ n, v_i ≠ u_i), the edges of the tree. It is guaranteed that the given edges form a tree. It is guaranteed that the tree contains at least one red vertex and at least one blue vertex.
Output
Print a single integer — the number of nice edges in the given tree.
Examples
Input
5
2 0 0 1 2
1 2
2 3
2 4
2 5
Output
1
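Problem summary: there are n vertices and n−1 edges, and each vertex has a color: 1 means red, 2 means blue, and 0 means uncolored. Count the edges whose removal splits the tree into two parts such that one part contains all of the red vertices and the other part contains all of the blue vertices.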
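This rephrasing of "neither component contains both colors" relies on the guarantee that at least one red and one blue vertex exist. A short derivation, using notation introduced here for illustration: cutting an edge leaves one component with r red and b blue vertices, and R, B denote the totals (the same names the code uses), so the other component has R − r red and B − b blue vertices.

$$
\begin{aligned}
\text{nice} &\iff \neg(r>0 \land b>0) \;\land\; \neg(R-r>0 \land B-b>0)\\
b = 0 &:\; B-b = B \ge 1,\ \text{so nice} \iff R-r = 0 \iff r = R\\
r = 0 &:\; R-r = R \ge 1,\ \text{so nice} \iff B-b = 0 \iff b = B\\
r>0,\ b>0 &:\ \text{the first component is mixed, so the edge is not nice}\\
\therefore\ \text{nice} &\iff (r = R \land b = 0) \lor (r = 0 \land b = B)
\end{aligned}
$$

Since R ≥ 1 and B ≥ 1, each of the two final cases puts every red vertex on one side and every blue vertex on the other.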
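Solution: first run a DFS from an arbitrary root to precompute, for every vertex, how many red and how many blue vertices lie in its subtree. Removing the edge between a vertex v and its parent cuts off exactly v's subtree, so that edge is nice if and only if v's subtree contains all the red vertices and no blue ones, or all the blue vertices and no red ones. A second DFS counts such edges.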
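#include <cstdio>
using namespace std;
const int maxn=600000+10;   // each undirected edge is stored twice, so room for 2(n-1) entries
int s[maxn];                // s[i]: color of vertex i (0 = uncolored, 1 = red, 2 = blue)
int R=0,B=0;                // total numbers of red and blue vertices
struct cc{
    int from,to;
}es[maxn];
int first[maxn],nxt[maxn];  // adjacency lists stored as linked lists over es[]
int tot=0;
// Add the directed edge ff -> tt to the adjacency list of ff.
void build(int ff,int tt)
{
    es[++tot]=cc{ff,tt};
    nxt[tot]=first[ff];
    first[ff]=tot;
}
struct ff{
    int R,B;                // numbers of red and blue vertices in the subtree
}num[maxn];
int ans=0;
// First pass: add each child's subtree counts into its parent,
// so num[x] ends up describing the whole subtree of x.
void dfs(int x,int fa)
{
    for(int i=first[x];i;i=nxt[i])
    {
        int v=es[i].to;
        if(v==fa) continue;
        dfs(v,x);
        num[x].R+=num[v].R;
        num[x].B+=num[v].B;
    }
}
// Second pass: removing edge (x, v) cuts off v's subtree, so the edge is nice
// iff that subtree holds all red vertices and no blue ones, or vice versa.
void dfss(int x,int fa)
{
    for(int i=first[x];i;i=nxt[i])
    {
        int v=es[i].to;
        if(v==fa) continue;
        if((num[v].R==R&&num[v].B==0)||(num[v].R==0&&num[v].B==B))
        {
            ans++;
        }
        dfss(v,x);
    }
}
int main()
{
    int n;
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
    {
        scanf("%d",&s[i]);
        if(s[i]==1)         // red vertex: counts toward its own subtree
        {
            R++;
            num[i].R++;
        }
        if(s[i]==2)         // blue vertex: counts toward its own subtree
        {
            B++;
            num[i].B++;
        }
    }
    for(int i=1;i<n;i++)    // read the n-1 edges and store both directions
    {
        int u,v;
        scanf("%d%d",&u,&v);
        build(u,v);
        build(v,u);
    }
    // Both passes recurse up to n levels deep (e.g. on a path),
    // so they rely on the judge providing a large enough stack.
    dfs(1,0);
    dfss(1,0);
    printf("%d\n",ans);
    return 0;
}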