之前没怎么做过这类题目,看起来很厉害,回头再总结一下。
下面是比赛时 WYP 写的代码,当时我都没有仔细看……
#include <cstdio>
#include <cstring>
#include <map>
#include <vector>
using namespace std;
// Capacity of every array indexed by vertex id.
#define maxn 100100
// Not referenced by the code visible in this file.
#define maxm 200005
// Read one int / two ints from stdin with scanf.
#define rd(x) scanf("%d", &x)
#define rd2(x, y) scanf("%d%d", &x, &y)
// Not referenced by the code visible in this file.
#define mod 1000000007

// V[i]: value read for vertex i; used as an index into f[].
int V[maxn];

// One directed arc in the linked adjacency lists.
struct Edge {
    int next;  // index of the next arc leaving the same vertex, -1 at the end
    int v;     // vertex this arc points to
};
Edge edge[maxn * 4];

int head[maxn];  // head[u]: most recently added arc of u, -1 when u has none
int vis[maxn];   // set to 1 once a vertex has been used as a decomposition root
int tot;         // number of arcs currently stored in edge[]

int n;  // vertex count of the current test case
int k;  // number of bits in the full mask, (1 << k) - 1

long long res;  // running pair count accumulated by cal()

int f[15];     // f[c] == 1 << (c - 1) for c = 1..10, filled by init()
int sz[maxn];  // sz[u]: subtree size computed by dfsize()
int mx[maxn];  // mx[u]: size of the largest piece around u (dfsize/dfsroot)
void addedge(int u, int v){
edge[tot].v = v; edge[tot].next = head[u]; head[u] = tot++;
}
map mp[maxn];
// Resets the per-test-case state. Reads the global n, so it must run
// after n has been set for the current test case.
void init(){
    // Drop the previous graph and all "already used as root" marks.
    tot = 0;
    memset(vis, 0, sizeof(vis));
    memset(head, -1, sizeof(head));

    // f[c] is the single-bit mask for value c, c = 1..10.
    for (int c = 1; c <= 10; ++c) {
        f[c] = 1 << (c - 1);
    }

    // Empty the mask counters of vertices 1..n.
    for (int v = n; v >= 1; --v) {
        mp[v].clear();
    }
}
/// Computes subtree sizes.
///
/// Walks the part of the tree reachable from x without crossing vertices
/// marked in vis[], treating fa as the parent of x. On return, for every
/// vertex u reached:
///   sz[u] = number of vertices in u's subtree (u included),
///   mx[u] = size of u's largest child subtree (0 for a leaf).
void dfsize(int x, int fa){
    sz[x] = 1;
    mx[x] = 0;
    for (int i = head[x]; i != -1; i = edge[i].next) {
        int v = edge[i].v;
        if (v == fa || vis[v]) continue;
        dfsize(v, x);
        sz[x] += sz[v];
        // Track the running maximum of x itself. Comparing against mx[v]
        // is always true and would leave mx[x] equal to the last child's size.
        if (sz[v] > mx[x]) mx[x] = sz[v];
    }
}
// Searches the component rooted at r for the vertex whose largest
// neighbouring piece is smallest.
//   r    - vertex dfsize() was started from (sz[r] is the component size)
//   x    - vertex currently examined, fa its parent in this walk
//   root - out: best vertex found so far
//   mm   - in/out: mx[] value of that best vertex
// Side effect: mx[x] is raised to include the piece on the parent side.
void dfsroot(int r, int x, int fa, int &root, int &mm){
    const int upward = sz[r] - sz[x];  // everything outside x's subtree
    if (mx[x] < upward) {
        mx[x] = upward;
    }
    if (mx[x] < mm) {
        mm = mx[x];
        root = x;
    }
    int e = head[x];
    while (e != -1) {
        const int child = edge[e].v;
        if (child != fa && !vis[child]) {
            dfsroot(r, child, x, root, mm);
        }
        e = edge[e].next;
    }
}
// Scratch buffer: path masks collected by cal() for the child subtree that
// is currently being traversed, before they are merged into mp[root].
std::vector<int> vec;
void cal(int x, int fa, int kk, int root){ // 计算当前子树中合法的点对数
int fx = f[V[x]];
kk = (kk | fx);
int k2 = (1 << k) -1 - kk;
int ks = (1 << k) -1;
if(x != root) {
for(map::iterator it = mp[root].begin(); it != mp[root].end(); it++){
int k3 = (*it).first;
if((k3 & k2) == k2) res += (*it).second;
}
vec.push_back(kk);
}
for(int i = head[x]; i != -1; i = edge[i].next){
int v = edge[i].v;
if(v == fa || vis[v]) continue;
cal(v, x, kk, root);
if(x == root){
int k2 = (1 << k) -1;
for(int i =0 ; i < vec.size(); i++){
mp[root][vec[i]]++;
//if(vec[i] == k2) res++;
}
vec.clear();
}
}
}
void solve(int x){
dfsize(x, x);
int mm = n + 1;
int root;
dfsroot(x, x, x, root, mm);
vec.clear();
mp[root][0]++;
cal(root, root, 0, root);
vis[root] = 1;
for(int i = head[root]; i != -1; i = edge[i].next){
int v = edge[i].v;
if(vis[v]) continue;
solve(v);
}
}
// Reads test cases until input ends. Each case is: n and k, then n vertex
// values, then n - 1 undirected edges. Prints one count per case.
int main()
{
    // Require both fields: `~scanf(...)` only stops on EOF and would spin
    // forever if scanf returned 0 on malformed input.
    while (scanf("%d%d", &n, &k) == 2) {
        for (int i = 1; i <= n; i++) {
            rd(V[i]);
        }
        init();
        res = 0;
        for (int i = 1; i < n; i++) {
            int a = 0;
            int b = 0;
            rd2(a, b);
            addedge(a, b);
            addedge(b, a);
        }
        // k == 1 is answered directly as n * n without the decomposition.
        if (k == 1) {
            printf("%lld\n", 1LL * n * n);
            continue;
        }
        solve(1);
        // cal() counts each pair once; the total is doubled before printing
        // (presumably because ordered pairs are required - confirm against
        // the problem statement).
        res = res * 2;
        printf("%lld\n", res);
    }
    return 0;
}