Power Station POJ 4045

题意:给你一棵树,让你求一点,使该点到其余各点的距离之和最小。如果这样的点有多个,则按升序依次输出。

树型dp

#include <cstdio>

#include <cstring>

#include <vector>

#include <set>

using namespace std;

const int maxn=50010;

typedef __int64 LL;

vector<int>tree[maxn];// to save the relation

LL f[maxn],g[maxn],dp[maxn];//f[u] u as root to all his sons' distance g[u]the number of u'sons

set<int> myqueue;

void dfs(int u,int pa){

/* 以1为根,所有的子树到他们的子节点的和 */

    if(tree[u].size() == 1 && u != 1){

        g[u]=1;

        f[u]=0;

        return;

    }

    for(int v=0;v < tree[u].size();v++){

        if(tree[u][v]!=pa){

            dfs(tree[u][v],u);

            g[u]+=g[tree[u][v]];//the son's sons' number

            f[u]+=f[tree[u][v]]+g[tree[u][v]];

        }

    }

    g[u]++;// himself

}

void dfs2(int u,int pa){// to sum the way from his father

//再考虑从父节点来的

    if(tree[u].size()==1 && u!=1){

        dp[u]=dp[pa]+g[1]-(g[u]<<1);

        return;

    }

    for(int v=0;v<tree[u].size();v++){

        if(tree[u][v]!= pa){

//到某节点的距离的和=该节点子树的距离和(在dfs1中获得)+从父亲那一支子树获得的和(此时,父节点那一支看成子树)。dp[儿子]=dp[父节点]-dp[儿子]-g[儿子](儿子到父节点这条路被减了子树的子节点数的次数)    + dp[儿子]+(g[1]-g[儿子])(父树上的所有节点)

            dp[tree[u][v]]=dp[u]+g[1]-(g[tree[u][v]]<<1);

            dfs2(tree[u][v],u);//先算了之后再跑子树

        }

    }

}

int main(){

    int t;

    int n,I,R,a,b;

    scanf("%d",&t);

    LL mmin;

    while(t--){

        scanf("%d%d%d",&n,&I,&R);



        for(int i=0;i<=n;i++){

            tree[i].clear();

        }

        for(int i=2;i<=n;i++){

            scanf("%d%d",&a,&b);

            tree[a].push_back(b);

            tree[b].push_back(a);

        }

//        for(int i=0;i<=n;i++){

//            for(int j=0;j<tree[i].size();j++){

//                printf("%d ",tree[i][j]);

//            }

//            printf("\n");

//        }

        memset(f,0,sizeof(f));

        memset(g,0,sizeof(g));

        dfs(1,-1);

        dp[1]=f[1];

        dfs2(1,-1);

        mmin=dp[1];

        myqueue.clear();

        myqueue.insert(1);

        for(int i=2;i<=n;i++){

            if(dp[i]<mmin){

                mmin=dp[i];

                myqueue.clear();

                myqueue.insert(i);

            }

            if(dp[i]==mmin) myqueue.insert(i);

        }

        //warning : output long long should be I64d

        printf("%I64d\n",I*I*R*mmin);//use set not num but the op

            for(set<int>::iterator it=myqueue.begin();it!=myqueue.end();++it){

                printf("%d ",*it);

            }

        printf("\n\n");

    }

}

  

你可能感兴趣的:(poj)