bzoj 1912 tree_dp

  这道题我们加一条路可以减少的代价为这条路两端点到lca的路径的长度,相当于一条链,那么如果加了两条链的话,这两条链重复的部分还是要走两遍,反而对答案没有了贡献(其实这个可以由任意两条链都可以看成两条不重叠的链来证明),那么这道题k=2的时候就转化为了求出树上两条链,使得两条链不重叠的长度最大,那么答案就是(n-1)<<1-SumLen+2.当k=1的时候我们直接求出来树的最长链然后减去就好了,这个在此不再赘述。

  对于树上两链不重复部分最大我们是可以tree_dp的,设w[i][0..4]来表示当前以i为根的子树中选取了0/1/2条链的最大值,同时我们保留了一个3,4来记录以i为一端点的最长链,同时选取了0/1条最长链的最大值,这样直接转移就好了。

  我写的是另外一种方法,先找出最长链,然后将最长链上的边长设为-1,然后再找一次最长链,这样求出来的就是答案。

  反思:开始没意识到第二次最长链不能用两边bfs,所以果断的写了bfs,后来才发现的,又临时加了一个tree_dp,因为加的路必须选,所以我们要将每个点的最长和次长链设为-inf,叶子节点的为0,然后用非叶子节点更新答案,然后竟然1A,真是感动= =。

/**************************************************************

    Problem: 1912

    User: BLADEVIL

    Language: C++

    Result: Accepted

    Time:1268 ms

    Memory:5884 kb

****************************************************************/

 

//By BLADEVIL

#include <cstdio>

#include <cstring>

#include <algorithm>

#define maxn 100010

#define maxm 200020

#define inf (~0U>>1)

 

using namespace std;

 

int n,k,l;

int pre[maxm],other[maxm],last[maxn],len[maxm];

int que[maxn],dis[maxn],father[maxn],flag[maxn],max_1[maxn],max_2[maxn];

 

void connect(int x,int y) {

    pre[++l]=last[x];

    last[x]=l;

    other[l]=y;

    len[l]=1;

}

 

void bfs(int x) {

    memset(que,0,sizeof que);

    memset(dis,0,sizeof dis);

    memset(father,0,sizeof father);

    memset(flag,0,sizeof flag);

    int h=0,t=1;

    que[1]=x; dis[x]=1; flag[x]=1;

    while (h<t) {

        int cur=que[++h];

        for (int p=last[cur];p;p=pre[p]) {

            if (flag[other[p]]) continue;

            father[other[p]]=p;

            dis[other[p]]=dis[cur]+len[p];

            flag[other[p]]=1;

            que[++t]=other[p];

        }

    }

}

 

int tree_dp() {

    int ans=-inf;

    memset(que,0,sizeof que);

    memset(flag,0,sizeof flag);

    memset(dis,0,sizeof dis);

    memset(max_1,-128,sizeof max_1);

    memset(max_2,-128,sizeof max_2);

    int h=0,t=1;

    que[1]=1; flag[1]=1; dis[1]=1;

    while (h<t) {

        int cur=que[++h];

        for (int p=last[cur];p;p=pre[p]) {

            if (flag[other[p]]) continue;

            que[++t]=other[p]; flag[other[p]]=1; dis[other[p]]=dis[cur]+1;

        }

    }

    //for (int i=1;i<=n;i++) printf("%d ",que[i]); printf("\n");

    for (int i=n;i;i--) {

        int cur=que[i];

        for (int p=last[cur];p;p=pre[p]) {

            if (dis[other[p]]<dis[cur]) continue;

            if (max_1[other[p]]+len[p]>max_1[cur])

                max_2[cur]=max_1[cur],max_1[cur]=max_1[other[p]]+len[p]; else

            if (max_1[other[p]]+len[p]>max_2[cur])

                max_2[cur]=max_1[other[p]]+len[p];

        }

        if (max_1[cur]<-100000000) max_1[cur]=max_2[cur]=0; else ans=max(ans,max(max_1[cur]+max_2[cur],max_1[cur]));

    }

    //for (int i=1;i<=n;i++) printf("|%d %d\n",max_1[i],max_2[i]);

    return ans;

}

 

int getmax() {

    int s=0;

    for (int i=1;i<=n;i++) if (dis[i]>dis[s]) s=i;

    return s;

}

 

int main() {

    scanf("%d%d",&n,&k); l=1;

    for (int i=1;i<n;i++) {

        int x,y; scanf("%d%d",&x,&y);

        connect(x,y); connect(y,x);

    }

    bfs(1); bfs(getmax());

    if (k==1) {

        printf("%d\n",2*n-dis[getmax()]);

        return 0;

    }

    int cur=getmax(),ans=dis[cur]-2;

    while (father[cur]) len[father[cur]]=len[father[cur]^1]=-1,cur=other[father[cur]^1];

    ans+=tree_dp()-1;

    printf("%d\n",2*n-2-ans);

    return 0;

}

 

你可能感兴趣的:(tree)