// 2019中国大学生程序设计竞赛 (2019 China Collegiate Programming Contest, CCPC) — problem "Checkout"

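// 思路：暴力枚举即可 — Approach: plain enumeration. For each departing node x, try every
// direct child of x as its replacement, evaluate the resulting pair count from
// precomputed totals, and keep the best one.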

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
using namespace std;
// Integer type (long long) used for every count, index and color in this solution.
typedef long long LL;
typedef LL lint;

// Large sentinel; dfs() starts its running maximum at -inf.
const lint inf = 0x3f3f3f3f;
// Array capacities: maxn for per-node arrays, maxm for edge slots
// (main() stores every undirected edge twice).
const lint maxn = 100005;
const lint maxm = 200005;

// a[v]: color of node v, read in main() (nodes are 1-indexed).
lint a[maxn];
// Adjacency list: he[v] is the newest edge leaving v, ne[e] the next older
// edge, ver[e] the endpoint of edge e; an index of 0 terminates a list.
lint he[maxn];
lint ne[maxm];
lint ver[maxm];
// tot: index of the last edge slot handed out by add().
lint tot;
// sum: number of same-colored parent/child pairs plus same-colored sibling pairs.
lint sum;
// n: number of nodes.
lint n;
// Appends the directed edge x -> y to the adjacency list.
// A fresh slot is taken by pre-incrementing tot. The new edge is linked in
// front of x's existing list, so he[x] always names the newest edge out of x.
void add( lint x,lint y ){
    const lint e = ++tot;
    ne[e] = he[x];
    ver[e] = y;
    he[x] = e;
}
// ve[v]: colors of v's children, filled and then sorted by dfs_pre().
// The element type was missing ("vector ve[maxn]" does not compile).
vector<lint> ve[maxn];
// Depth-first pass over the subtree of x (f is x's parent, 0 for the root).
// For every child y, appends a[y] to ve[x] and adds 1 to sum when
// a[y] == a[x]. This counts the same-colored parent/child pairs.
// Once all children are visited, ve[x] is sorted so later code can group it
// into runs and binary-search it.
// Recursive: the call depth equals the depth of the subtree below x.
void dfs_pre( lint x,lint f ){
    for( lint e = he[x]; e != 0; e = ne[e] ){
        const lint child = ver[e];
        if( child != f ){
            ve[x].push_back( a[child] );
            sum += ( a[x] == a[child] ) ? 1 : 0;
            dfs_pre( child,x );
        }
    }
    sort( ve[x].begin(),ve[x].end() );
}
// Adds to sum the number of unordered pairs of x's children that share a color.
// A run of k equal values in ve[x] contributes k*(k-1)/2 pairs.
// Assumes ve[x] is already sorted (dfs_pre sorts it before this is called).
void solve_pre( lint x ){
    const lint len = ve[x].size();
    lint pairs = 0;
    lint i = 0;
    while( i < len ){
        // Extend j to the end of the run of values equal to ve[x][i].
        lint j = i;
        while( j < len && ve[x][j] == ve[x][i] ) j++;
        const lint run = j - i;
        pairs += run * ( run - 1 ) / 2;
        i = j;
    }
    sum += pairs;
}
void pre_work(){
    dfs_pre(1,0);
    for( lint i = 1;i <= n;i++ ){
        solve_pre(i);
    }
}
// ans[v]: value printed for node v; filled by dfs() and printed by main().
lint ans[maxn] = {};
// Returns the sum, over every entry of ve[x] (x's child colors, sorted), of the
// number of entries in ve[y] (y's child colors, sorted) with the same color.
// A color that occurs k times in ve[x] therefore contributes k times.
// Both lists are walked once with two pointers. Consecutive equal entries in
// ve[x] reuse the match count computed for the first entry of their run.
lint dfs_solve( lint x,lint y ){
    const lint nx = ve[x].size();
    const lint ny = ve[y].size();
    lint cur = 0,px = 0,py = 0;
    lint res = 0;
    while( px < nx ){
        // Compare with the previous entry only when one exists. The old guard
        // (x >= 1) was always true and read ve[x][-1] at px == 0.
        if( px >= 1 && ve[x][px] == ve[x][px-1] ){
            res += cur;
        }else {
            cur = 0;
            // Skip smaller colors in ve[y], then count the equal ones.
            while( py < ny && ve[y][py] < ve[x][px] ) py++;
            while( py < ny && ve[y][py] == ve[x][px] ) {
                cur++;
                py++;
            }
            res += cur;
        }
        px++;
    }
    return res;
}
// Returns how many entries of the sorted list ve[x] equal c, i.e. how many
// children of x have color c.
lint dfs_solve2( lint c,lint x ){
    // [lower_bound, upper_bound) is exactly the run of c in the sorted list.
    return upper_bound( ve[x].begin(),ve[x].end(),c )
         - lower_bound( ve[x].begin(),ve[x].end(),c );
}
// Fills ans[] for every node in the subtree of x (f is x's parent, 0 for the root).
//
// ans[x] is computed as:
//   sum
//   - pairs that involve x: (f,x), (x,child), (x,sibling)
//   + the best gain over all children y promoted into x's position.
// A leaf has no child to promote, so its gain is 0.
// As modelled here, x's other children end up under y. Pairs (y,c) between y
// and a former sibling c therefore turn from sibling pairs into parent/child
// pairs and cancel out; they are not counted in either direction.
// The root (f == 0) has no parent and no siblings, so every parent/sibling
// term is skipped for it. Before this, the sibling gain was not guarded, so a
// root child sharing the root's color got gain -1 and could make ans negative.
// Recursive: the call depth equals the depth of the subtree below x.
void dfs( lint x,lint f ){
    const bool has_parent = ( f != 0 );

    ans[x] = sum;
    // The pair (f, x) disappears.
    if( has_parent && a[x] == a[f] ) ans[x]--;
    // Pairs (x, child) with equal colors disappear.
    ans[x] -= dfs_solve2( a[x],x );
    // Sibling pairs (x, s) disappear; the -1 removes x itself from f's list.
    if( has_parent ) ans[x] -= dfs_solve2( a[x],f ) - 1;

    lint best = -inf;
    bool has_child = false;
    for( lint e = he[x];e;e = ne[e] ){
        const lint y = ver[e];
        if( y == f ) continue;
        has_child = true;
        lint gain = 0;
        // y becomes f's child.
        if( has_parent && a[y] == a[f] ) gain++;
        // y's children become siblings of x's other children...
        gain += dfs_solve( x,y );
        // ...but dfs_solve also matched y's children against y's own entry in
        // ve[x]; those are existing (y, child) pairs, not new sibling pairs.
        gain -= dfs_solve2( a[y],y );
        // y becomes a sibling of x's former siblings (x itself excluded).
        if( has_parent ) gain += dfs_solve2( a[y],f ) - ( ( a[y] == a[x] ) ? 1 : 0 );
        best = max( best,gain );
    }
    if( !has_child ) best = 0;
    ans[x] += best;

    for( lint e = he[x];e;e = ne[e] ){
        const lint y = ver[e];
        if( y == f ) continue;
        dfs( y,x );
    }
}
// Reads n and m, the n node colors, and n-1 undirected edges. Builds the tree
// rooted at node 1 and prints ans[1..n] separated by single spaces, followed
// by a newline.
// m is read but not used below (presumably the number of colors — confirm
// against the problem statement).
// Stops quietly if the input ends early instead of computing on garbage.
int main(){
    lint m;
    // Edge slots start at 2 so that index 0 stays the end-of-list marker in he/ne.
    tot = 1;
    if( scanf("%lld%lld",&n,&m) != 2 ) return 0;
    for( lint i = 1;i <= n;i++ ){
        if( scanf("%lld",&a[i]) != 1 ) return 0;
    }
    for( lint i = 1;i <= n-1;i++ ){
        lint x,y;
        if( scanf("%lld%lld",&x,&y) != 2 ) return 0;
        add(x,y);
        add(y,x);
    }
    pre_work();
    dfs(1,0);
    // Use C stdio only; the original mixed printf with cout/endl.
    for( lint i = 1;i <= n;i++ ){
        if( i > 1 ) putchar(' ');
        printf("%lld",ans[i]);
    }
    putchar('\n');
    return 0;
}

 

// 你可能感兴趣的 (blog footer — "You may also be interested in"): 2019中国大学生程序设计竞赛 Checkout