spoj 1825
题目:http://www.spoj.com/problems/FTOUR2/
题目大意:给你一棵节点数为 N 的树,每条树枝有权值,点有黑白两色,问你找一条路径使其进过的黑色的节点数不超过 K 且权值和最大,然后输出这个权值。
思路:继上题的 Tree 之后,漆子超论文的下一道题目,表示看论文、题解和别人代码做了很久。。 = =
用G[ i ][ j ] 表示根节点 的第 i 个儿子经过的黑点数为 j 的最优值,但是 i、j 范围太大了,空间开不下。但是我们不需要保存所有的儿子对应的所有 j ,我们只关心已经算过的节点中每个 j 对应的最大值,所以这里需要优化一下:用f [ i ] 表示已处理的节点中黑点数不超过 i 的最优值,显然,f 具有单调递增性。然后对于当前要处理的节点,算出 g[ i ] ,,然后用g[ i ] 去和 f[ i ]组合更新 ans ,再更新 f[ i ] 就行了,注意:f[ i ] 更新好以后,也要用符合要求的 i 更新 ans ,因为 g[ i ] 有可能不和 f[ i ] 结合,即,根节点为路径的起点或终点。对于每一个根节点,需要对每个儿子按照 dep[ i ] 先进行排序,然后每次只要更新到dep[ i ] 就行了,时间复杂度为排序的复杂度O(NlogN),如果不排序,则最坏情况下,时间复杂度会达到 O(N^2 )。
基本上是照着别人代码来的,先开始一直TLE,然后WA,然后RE,最后才 AC 的。。 = = WA的原因在于,我把 solve()里的 getted[ root ] =0, 写成 getted[ x ] =0 了。。 T T,而TLE是在于找 root 时,我是按照上一题的方法来的,代码是这样的:
int num[MAXN],maxv[MAXN]; vector <int> node; void dfs(int u,int fa) { node.push_back(u); num[u]=1; for(int e = head[u];e!=-1;e= edge[e].next) { int v = edge[e].t; if(getted[v]||v==fa) continue; dfs(v,u); num[u]+=num[v]; maxv[u] = max(maxv[u],num[v]); } } int get_root(int x) { node.clear(); dfs(x,0); int minn = INF; int sum_node = num[x]; int root; for(int i=0;i<node.size();i++) { int cur = node[i]; maxv[cur] = max(maxv[cur],sum_node-num[cur]); if(maxv[cur]<minn) { minn = maxv[cur]; root = cur; } } return root; }然后看了别人的,改成下面这个就过了:
void dfs1(int u,int fa) { father[u]=fa; num[u]=1; maxv[u]=0; for(int e = head[u];e!=-1;e= edge[e].next) { int v = edge[e].t; if(getted[v]||v==fa) continue; dfs1(v,u); num[u]+=num[v]; maxv[u] = max(maxv[u],num[v]); } } int minn; void dfs2(int u,int sum,int& root) { for(int e = head[u] ; e!=-1 ; e =edge[e].next) { int v = edge[e].t; if(getted[v]||v==father[u]) continue; dfs2(v,sum,root); } int tmp = max(sum-num[u],maxv[u]); if(tmp<minn) { minn = tmp; root = u; } } int get_root(int x) { dfs1(x,0); minn = INF; int sum_node = num[x]; int root; dfs2(x,sum_node,root); return root; }
还有,如果把上面那段代码里 dfs2()里和 dfs1()一样加个 fa ,把 dfs1()里的 father 数组去掉,交上去,竟然是RE。。 想不明白啊,想不明白。。
好吧,改来改去,终于AC了,代码如下:
#include<cstdio> #include<cstring> #include<vector> #include<algorithm> using namespace std; const int INF = 0x0fffffff ; const int MAXN = 400022 ; int n,m,k; struct Edge { int t,next,len; } edge[MAXN<<1]; int head[MAXN],tot; void add_edge(int s,int t,int len) { edge[tot].len=len; edge[tot].t=t; edge[tot].next = head[s]; head[s] = tot++; } bool is_black[MAXN],getted[MAXN]; int num[MAXN],maxv[MAXN]; int father[MAXN]; void dfs1(int u,int fa) { father[u]=fa; num[u]=1; maxv[u]=0; for(int e = head[u];e!=-1;e= edge[e].next) { int v = edge[e].t; if(getted[v]||v==fa) continue; dfs1(v,u); num[u]+=num[v]; maxv[u] = max(maxv[u],num[v]); } } int minn; void dfs2(int u,int sum,int& root) { for(int e = head[u] ; e!=-1 ; e =edge[e].next) { int v = edge[e].t; if(getted[v]||v==father[u]) continue; dfs2(v,sum,root); } int tmp = max(sum-num[u],maxv[u]); if(tmp<minn) { minn = tmp; root = u; } } int get_root(int x) { dfs1(x,0); minn = INF; int sum_node = num[x]; int root; dfs2(x,sum_node,root); return root; } int dep[MAXN]; void get_dep(int u,int fa) { dep[u] = is_black[u]; for(int e = head[u];e!=-1;e = edge[e].next) { int v = edge[e].t; if(getted[v]||v==fa) continue; get_dep(v,u); dep[u] = max(dep[u],is_black[u]+dep[v]); } } int g[MAXN]; void get_g(int u,int fa,int s,int c) { g[c] = max(g[c],s); for(int e = head[u] ; e!=-1 ;e=edge[e].next) { int v = edge[e].t; int len = edge[e].len; if(getted[v]||v==fa) continue; get_g(v,u,s+len,c+is_black[v]); } } int id[MAXN]; bool cmp(int a,int b) { return dep[edge[a].t]<dep[edge[b].t]; } int ans; int f[MAXN]; void solve(int x) { int root = get_root(x); //printf("root = %d\n",root); getted[root]=1; for(int e = head[root]; e!=-1 ;e=edge[e].next) { int v = edge[e].t; if(getted[v]) continue; solve(v); } int cc=0; for(int e = head[root];e!=-1; e =edge[e].next) { int v = edge[e].t; if(getted[v]) continue; get_dep(v,root); id[cc++]=e; } sort(id,id+cc,cmp); for(int i=0;i<=dep[edge[id[cc-1]].t];i++) f[i]=-INF; //printf("root = %d\n",root); for(int i=0;i<cc;i++) { int cur = edge[id[i]].t; int len = edge[id[i]].len; for(int j=0;j<=dep[cur];j++) g[j]=-INF; get_g(cur,root,len,is_black[cur]); //printf("cur = %d\n",cur); if(i>0) { int end = min(k - is_black[root],dep[cur]); for(int j=0;j<=end;j++) { int p = min(dep[edge[id[i-1]].t],k-j-is_black[root]); //printf("g[%d] = %d,f[%d] = %d\n",j,g[j],p,f[p]); if(f[p]==-INF) break; if(g[j]!=-INF) ans=max(ans,g[j]+f[p]); } } for(int j=0;j<=dep[cur];j++) { f[j]=max(f[j],g[j]); if(j>0) f[j]=max(f[j],f[j-1]); if(i==0&&j+is_black[root]<=k) ans = max(ans,f[j]); } } //printf("ans = %d\n",ans); getted[root]=0; } int main() { while(~scanf("%d%d%d",&n,&k,&m)) { memset(is_black,0,sizeof(is_black)); int pos; for(int i=0;i<m;i++) { scanf("%d",&pos); is_black[pos]=1; } memset(head,-1,sizeof(head)); tot=0; int a,b,c; for(int i=1;i<n;i++) { scanf("%d%d%d",&a,&b,&c); add_edge(a,b,c); add_edge(b,a,c); } ans=0; memset(getted,0,sizeof(getted)); solve(1); printf("%d\n",ans); } return 0; } /* 7 5 6 2 3 4 5 6 7 1 7 100 1 5 100 5 6 100 1 2 1 2 3 1 3 4 1 */