spoj 1825 (树的分治)

spoj 1825

题目:http://www.spoj.com/problems/FTOUR2/

题目大意:给你一棵节点数为 N 的树,每条树枝有权值,点有黑白两色,问你找一条路径使其进过的黑色的节点数不超过 K 且权值和最大,然后输出这个权值。

思路:继上题的 Tree 之后,漆子超论文的下一道题目,表示看论文、题解和别人代码做了很久。。 = =

用G[ i ][ j ] 表示根节点 的第 i 个儿子经过的黑点数为 j 的最优值,但是 i、j 范围太大了,空间开不下。但是我们不需要保存所有的儿子对应的所有 j ,我们只关心已经算过的节点中每个 j 对应的最大值,所以这里需要优化一下:用f [ i ]  表示已处理的节点中黑点数不超过 i 的最优值,显然,f 具有单调递增性。然后对于当前要处理的节点,算出 g[ i ] ,,然后用g[ i ] 去和 f[ i ]组合更新 ans ,再更新 f[ i ] 就行了,注意:f[ i ]  更新好以后,也要用符合要求的 i 更新 ans ,因为 g[ i ] 有可能不和 f[ i ] 结合,即,根节点为路径的起点或终点。对于每一个根节点,需要对每个儿子按照 dep[ i ] 先进行排序,然后每次只要更新到dep[ i ] 就行了,时间复杂度为排序的复杂度O(NlogN),如果不排序,则最坏情况下,时间复杂度会达到  O(N^2 )。

基本上是照着别人代码来的,先开始一直TLE,然后WA,然后RE,最后才 AC 的。。 = =  WA的原因在于,我把 solve()里的 getted[ root ] =0, 写成 getted[ x ] =0 了。。 T T,而TLE是在于找 root 时,我是按照上一题的方法来的,代码是这样的:

int num[MAXN],maxv[MAXN];

vector <int> node;

void dfs(int u,int fa)
{
    node.push_back(u);
    num[u]=1;
    for(int e = head[u];e!=-1;e= edge[e].next)
    {
        int v = edge[e].t;
        if(getted[v]||v==fa) continue;
        dfs(v,u);
        num[u]+=num[v];
        maxv[u] = max(maxv[u],num[v]);
    }
}

int get_root(int x)
{
    node.clear();
    dfs(x,0);
    int minn = INF;
    int sum_node = num[x];
    int root;
    for(int i=0;i<node.size();i++)
    {
        int cur = node[i];
        maxv[cur] = max(maxv[cur],sum_node-num[cur]);
        if(maxv[cur]<minn)
        {
            minn = maxv[cur];
            root = cur;
        }
    }
    return root;
}
然后看了别人的,改成下面这个就过了:

void dfs1(int u,int fa)
{
    father[u]=fa;
    num[u]=1;
    maxv[u]=0;
    for(int e = head[u];e!=-1;e= edge[e].next)
    {
        int v = edge[e].t;
        if(getted[v]||v==fa) continue;
        dfs1(v,u);
        num[u]+=num[v];
        maxv[u] = max(maxv[u],num[v]);
    }
}

int minn;

void dfs2(int u,int sum,int& root)
{
    for(int e = head[u] ; e!=-1 ; e =edge[e].next)
    {
        int v = edge[e].t;
        if(getted[v]||v==father[u]) continue;
        dfs2(v,sum,root);
    }
    int tmp = max(sum-num[u],maxv[u]);
    if(tmp<minn)
    {
        minn = tmp;
        root = u;
    }
}

int get_root(int x)
{
    dfs1(x,0);
    minn = INF;
    int sum_node = num[x];
    int root;
    dfs2(x,sum_node,root);
    return root;
}

这复杂度不是一样的嘛,想不清楚。。= =

还有,如果把上面那段代码里 dfs2()里和 dfs1()一样加个 fa ,把 dfs1()里的 father 数组去掉,交上去,竟然是RE。。 想不明白啊,想不明白。。

好吧,改来改去,终于AC了,代码如下:

#include<cstdio>
#include<cstring>
#include<vector>
#include<algorithm>
using namespace std;

const int INF = 0x0fffffff ;

const int MAXN = 400022 ;

int n,m,k;

struct Edge
{
    int t,next,len;
} edge[MAXN<<1];

int head[MAXN],tot;

void add_edge(int s,int t,int len)
{
    edge[tot].len=len;
    edge[tot].t=t;
    edge[tot].next = head[s];
    head[s] = tot++;
}

bool is_black[MAXN],getted[MAXN];

int num[MAXN],maxv[MAXN];

int father[MAXN];

void dfs1(int u,int fa)
{
    father[u]=fa;
    num[u]=1;
    maxv[u]=0;
    for(int e = head[u];e!=-1;e= edge[e].next)
    {
        int v = edge[e].t;
        if(getted[v]||v==fa) continue;
        dfs1(v,u);
        num[u]+=num[v];
        maxv[u] = max(maxv[u],num[v]);
    }
}

int minn;

void dfs2(int u,int sum,int& root)
{
    for(int e = head[u] ; e!=-1 ; e =edge[e].next)
    {
        int v = edge[e].t;
        if(getted[v]||v==father[u]) continue;
        dfs2(v,sum,root);
    }
    int tmp = max(sum-num[u],maxv[u]);
    if(tmp<minn)
    {
        minn = tmp;
        root = u;
    }
}

int get_root(int x)
{
    dfs1(x,0);
    minn = INF;
    int sum_node = num[x];
    int root;
    dfs2(x,sum_node,root);
    return root;
}

int dep[MAXN];

void get_dep(int u,int fa)
{
    dep[u] = is_black[u];
    for(int e = head[u];e!=-1;e = edge[e].next)
    {
        int v = edge[e].t;
        if(getted[v]||v==fa) continue;
        get_dep(v,u);
        dep[u] = max(dep[u],is_black[u]+dep[v]);
    }
}

int g[MAXN];

void get_g(int u,int fa,int s,int c)
{
    g[c] = max(g[c],s);
    for(int e = head[u] ; e!=-1 ;e=edge[e].next)
    {
        int v = edge[e].t;
        int len = edge[e].len;
        if(getted[v]||v==fa) continue;
        get_g(v,u,s+len,c+is_black[v]);
    }
}

int id[MAXN];

bool cmp(int a,int b)
{
    return dep[edge[a].t]<dep[edge[b].t];
}

int ans;

int f[MAXN];

void solve(int x)
{
    int root = get_root(x);
    //printf("root = %d\n",root);
    getted[root]=1;
    for(int e = head[root]; e!=-1 ;e=edge[e].next)
    {
        int v = edge[e].t;
        if(getted[v]) continue;
        solve(v);
    }
    int cc=0;
    for(int e = head[root];e!=-1; e =edge[e].next)
    {
        int v = edge[e].t;
        if(getted[v]) continue;
        get_dep(v,root);
        id[cc++]=e;
    }
    sort(id,id+cc,cmp);
    for(int i=0;i<=dep[edge[id[cc-1]].t];i++)
        f[i]=-INF;
    //printf("root = %d\n",root);
    for(int i=0;i<cc;i++)
    {
        int cur = edge[id[i]].t;
        int len = edge[id[i]].len;
        for(int j=0;j<=dep[cur];j++)
            g[j]=-INF;
        get_g(cur,root,len,is_black[cur]);
        //printf("cur = %d\n",cur);
        if(i>0)
        {
            int end = min(k - is_black[root],dep[cur]);
            for(int j=0;j<=end;j++)
            {
                int p = min(dep[edge[id[i-1]].t],k-j-is_black[root]);
                //printf("g[%d] = %d,f[%d] = %d\n",j,g[j],p,f[p]);
                if(f[p]==-INF) break;
                if(g[j]!=-INF)
                    ans=max(ans,g[j]+f[p]);
            }
        }
        for(int j=0;j<=dep[cur];j++)
        {
            f[j]=max(f[j],g[j]);
            if(j>0) f[j]=max(f[j],f[j-1]);
            if(i==0&&j+is_black[root]<=k) ans = max(ans,f[j]);
        }
    }
    //printf("ans = %d\n",ans);
    getted[root]=0;
}

int main()
{
    while(~scanf("%d%d%d",&n,&k,&m))
    {
        memset(is_black,0,sizeof(is_black));
        int pos;
        for(int i=0;i<m;i++)
        {
            scanf("%d",&pos);
            is_black[pos]=1;
        }
        memset(head,-1,sizeof(head));
		tot=0;
        int a,b,c;
        for(int i=1;i<n;i++)
        {
            scanf("%d%d%d",&a,&b,&c);
            add_edge(a,b,c);
            add_edge(b,a,c);
        }
        ans=0;
        memset(getted,0,sizeof(getted));
        solve(1);
        printf("%d\n",ans);
    }
    return 0;
}

/*
7 5 6
2
3
4
5
6
7
1 7 100
1 5 100
5 6 100
1 2 1
2 3 1
3 4 1
*/


你可能感兴趣的:(spoj 1825 (树的分治))