题目大意:
就是一个有N个点的树现在要从这棵树上选出K个点使得这K个点两两之间没有祖先关系, 即任意一个都不是另外一个的祖先, 那么选出K个这样的点能得到的最大的权值和是多少
大致思路:
首先这题如果K不大的话可以用树形DP直接弄, 不过这个题K比较大, 于是需要用一个巧妙的方法
先膜拜一下vawait....
好了进入正题
首先处理出每个节点u, 和其子树中所有节点的权值中的最大值w[u]
然后处理出对于每个节点u, 从其子树中选出两个没有祖先关系的点的权值和的最大值(不包括节点u自己) f[u]
那么用优先队列维护一下,
初始的时候加入(w[root], root)
然后每次选择当前最大的出队(w[u], u), 如果是用过的(即u及其子树中选择了最大值w[u])那么就出队, 取消这个点的选择状态, 将所有u的儿子(w[u], u)入队
否则就选择这个点, 然后将(f[u] - w[u], u)入队
这样保证了当u有两个子节点权值和比w[u](u及其子树中任意一个结点权值的最大值)大的时候, 会放弃选择u而转向选择u的子节点
代码如下:
Result : Accepted Memory : 7428 KB Time : 608 ms
/* * Author: Gatevin * Created Time: 2015/8/10 20:10:18 * File Name: Sakura_Chiyo.cpp */ #include<iostream> #include<sstream> #include<fstream> #include<vector> #include<list> #include<deque> #include<queue> #include<stack> #include<map> #include<set> #include<bitset> #include<algorithm> #include<cstdio> #include<cstdlib> #include<cstring> #include<cctype> #include<cmath> #include<ctime> #include<iomanip> using namespace std; const double eps(1e-8); typedef long long lint; #define maxn 100010 int N, K, root; int w[maxn]; bool select[maxn]; vector<int> G[maxn]; int f[maxn];//f[u]表示从u的子孙节点中选两个权值和的最大值 void dfs(int now)//处理出每个子树u及其子树中最小的权值w[u], 以及其子树中两个没有关联的父亲关系的点权值最大和 { int nex; int mx = -1e9;//表示前几个子树中的最小w for(int i = 0, sz = G[now].size(); i < sz; i++) { nex = G[now][i]; dfs(G[now][i]); f[now] = max(f[now], max(f[nex], w[nex] + mx)); w[now] = max(w[now], w[nex]); mx = max(mx, w[nex]); } return; } void solve() { memset(select, 0, sizeof(select)); priority_queue<pair<int, int> > Q; Q.push(make_pair(w[root], root)); int ret = 0; while(K) { if(Q.empty())//选不出K个 { puts("0"); return; } int u = Q.top().second; Q.pop(); if(select[u])//说明现在选择u的子节点中的两个要更优 { K++; select[u] = 0;//放弃选择子树的根节点u ret -= w[u]; for(int i = 0, sz = G[u].size(); i < sz; i++) Q.push(make_pair(w[G[u][i]], G[u][i]));//加入其儿子节点 } else { K--; select[u] = 1; ret += w[u];//选择了u结点 Q.push(make_pair(f[u] - w[u], u));//那么当选择两个其子节点和和它的差值要加入队列, 以确定是否要放弃这个结点 } } printf("%d\n", ret); return; } int main() { while(scanf("%d %d", &N, &K), N || K) { for(int i = 1; i <= N; i++) G[i].clear(); int p; for(int i = 1; i <= N; i++) { f[i] = -1e9; scanf("%d %d", &p, w + i); if(p == 0) root = i; else G[p].push_back(i); } dfs(root); solve(); } return 0; }