BZOJ1912: 巡逻 题解

这道题很像topcoder里的一道题kingdomtour,是它的弱化版,可以看我的那道题的博客,一个树型dp,复杂度 O(nk2) O ( n k 2 )

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define LL long long
#define LB long double
#define x first
#define y second
#define Pair pair
#define pb push_back
#define pf push_front
#define mp make_pair
#define LOWBIT(x) x & (-x)
using namespace std;

const int MOD=998244353;
const LL LINF=2e16;
const int INF=1e9;
const int magic=348;
const double eps=1e-10;

inline int getint()
{
    char ch;int res;bool f;
    while (!isdigit(ch=getchar()) && ch!='-') {}
    if (ch=='-') f=false,res=0; else f=true,res=ch-'0';
    while (isdigit(ch=getchar())) res=res*10+ch-'0';
    return f?res:-res;
}

int n,K;
vector<int> v[100048];

int dp[100048][10],tmp[100048][10];

inline void dfs(int cur,int father)
{
    int i,j,k,y,cc=0;
    for (i=0;i<int(v[cur].size());i++)
    {
        y=v[cur][i];
        if (y!=father) dfs(y,cur);
    }
    for (i=0;i<=K*2;i++) tmp[cc][i]=0;
    for (i=0;i<int(v[cur].size());i++)
    {
        y=v[cur][i];
        if (y!=father)
        {
            cc++;
            for (j=0;j<=K*2;j++) tmp[cc][j]=INF;
            for (j=0;j<=K*2;j++)
                for (k=0;k<=j;k++)
                    tmp[cc][j]=min(tmp[cc][j],tmp[cc-1][j-k]+dp[y][k]+((k&1)?1:2));
        }
    }
    for (i=0;i<=K*2;i++) dp[cur][i]=tmp[cc][i];
}

int main ()
{
    int i,x,y;
    n=getint();K=getint();
    for (i=1;i<=n-1;i++)
    {
        x=getint();y=getint();
        v[x].pb(y);v[y].pb(x);
    }
    dfs(1,-1);
    printf("%d\n",dp[1][K*2]+K);
    return 0;
}

上网看了一些题解,发现k比较小的时候,有一种找直径的算法非常巧妙
先考虑k=1的情况
我们发现,连接u,v两个点之后,u~v的路径上的点都只会被访问一次,所以要使得减少的边最多,我们应该找树的直径,这个比较显然
再考虑k=2的情况,我们发现,如果两条链有公共部分,那么这个公共部分的边还是要走两次的
进而我们发现,如果两条链有公共部分,那么一定可以在使答案不变坏的前提下换一种方案,使得两条链没有公共部分,大概是现在的第2条链的一半和第1条链的一半接起来形成新链之类的
于是我们这样做:
先对原树找一条直径
然后把直径上的边的边权从1改成-1
然后再对原树找一条直径
这两次的答案合起来就是最优方案
我们发现这个把1改成-1的操作很像网络流里面的反向边,选择这条边相当于把原来选的直径退掉

*注意当一棵树内的边有负权的时候,不能通过两次dfs找直径,得老老实实写个树型dp(其实反而更好写?雾)

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#define LL long long
#define LB long double
#define x first
#define y second
#define Pair pair
#define pb push_back
#define pf push_front
#define mp make_pair
#define LOWBIT(x) x & (-x)
using namespace std;

const int MOD=998244353;
const LL LINF=2e16;
const int INF=1e9;
const int magic=348;
const double eps=1e-10;

inline int getint()
{
    char ch;int res;bool f;
    while (!isdigit(ch=getchar()) && ch!='-') {}
    if (ch=='-') f=false,res=0; else f=true,res=ch-'0';
    while (isdigit(ch=getchar())) res=res*10+ch-'0';
    return f?res:-res;
}

int n,k;
vector<int> v[100048];

struct Edge
{
    int x,y;
    int len;
}edge[100048];

int ans=0;

int sum[100048],fa[100048],faind[100048];

inline int Get(int ind,int cur)
{
    if (edge[ind].x==cur) return edge[ind].y; else return edge[ind].x;
}

inline void dfs(int cur,int father)
{
    int i,y;
    for (i=0;i<int(v[cur].size());i++)
    {
        y=Get(v[cur][i],cur);
        if (y!=father)
        {
            sum[y]=sum[cur]+edge[v[cur][i]].len;
            fa[y]=cur;faind[y]=v[cur][i];
            dfs(y,cur);
        }
    }
}

inline int find_dia()
{
    sum[1]=0;fa[1]=-1;dfs(1,-1);
    int maxn=-INF,maxpos,i;
    for (i=1;i<=n;i++)
        if (sum[i]>maxn)
        {
            maxn=sum[i];
            maxpos=i;
        }
    sum[maxpos]=0;fa[maxpos]=-1;dfs(maxpos,-1);
    maxn=-INF;
    for (i=1;i<=n;i++)
        if (sum[i]>maxn)
        {
            maxn=sum[i];
            maxpos=i;
        }
    return maxpos;
}

inline void update(int cur)
{
    while (fa[cur]!=-1)
    {
        edge[faind[cur]].len=-1;
        cur=fa[cur];
    }
}

int res=0,dp[100048];
inline void Dfs(int cur,int father)
{
    int i,y;dp[cur]=0;
    for (i=0;i<int(v[cur].size());i++)
    {
        y=Get(v[cur][i],cur);
        if (y!=father)
        {
            Dfs(y,cur);
            res=max(res,dp[y]+edge[v[cur][i]].len+dp[cur]);
            dp[cur]=max(dp[cur],dp[y]+edge[v[cur][i]].len);
        }
    }
}

int main ()
{
    int i,x,y,ans=0;
    n=getint();k=getint();
    for (i=1;i<=n-1;i++)
    {
        x=getint();y=getint();
        v[x].pb(i);v[y].pb(i);
        edge[i]=Edge{x,y,1};
    }
    int ed=find_dia();
    ans+=sum[ed];
    if (k==1) {printf("%d\n",(n-1)*2-ans+1);return 0;}
    update(ed);
    Dfs(1,-1);ans+=res;
    printf("%d\n",(n-1)*2-ans+2);
    return 0;
}

你可能感兴趣的:(树型dp)