BZOJ 2654 tree 二分答案+Kruskal

题目大意:给定一张带权无向图,每条边有一个颜色(黑色/白色),求一棵生成树满足有 need 条白色边且权值和最小
二分一个 x ,然后将所有白边权值加上 x ,跑两遍Kruskal,第一遍白边排在前面,第二遍黑边排在前面,这样可以求出当前白边数量的最大最小值
如果 need 在最大最小值之间那么直接输出结果,否则如果小于最小值就增大 x ,大于最大值就减小 x
然而我并不会证明正确性。。。

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define M 100100
using namespace std;
struct edge{
    int x,y,z,col;
    friend istream& operator >> (istream &_,edge &e)
    {
        return scanf("%d%d%d%d",&e.x,&e.y,&e.z,&e.col),++e.x,++e.y,_;
    }
}edges[M];
bool Compare1 (const edge &e1,const edge &e2)//白边优先
{
    if( e1.z!=e2.z )
        return e1.z<e2.z;
    return e1.col<e2.col;
}
bool Compare2 (const edge &e1,const edge &e2)//黑边优先
{
    if( e1.z!=e2.z )
        return e1.z<e2.z;
    return e1.col>e2.col;
}
int n,m,need;
namespace Union_Find_Set{
    int fa[M],rank[M];
    void Initialize()
    {
        memset(fa,0,sizeof fa);
        memset(rank,0,sizeof rank);
    }
    int Find(int x)
    {
        if(!fa[x]||fa[x]==x)
            return fa[x]=x;
        return fa[x]=Find(fa[x]);
    }
    void Union(int x,int y)
    {
        x=Find(x);y=Find(y);
        if(x==y) return ;
        if(rank[x]>rank[y])
            swap(x,y);
        if(rank[x]==rank[y])
            ++rank[y];
        fa[x]=y;
    }
}
bool Check(int x)//need小于最少边数返回1 大于最大反回0
{
    using namespace Union_Find_Set;
    int i,min_cnt=0,max_cnt=0,re=0;
    for(i=1;i<=m;i++)
        if(edges[i].col==0)
            edges[i].z+=x;
    Initialize();
    sort(edges+1,edges+m+1,Compare1);
    for(i=1;i<=m;i++)
    {
        int x=Find(edges[i].x);
        int y=Find(edges[i].y);
        if(x==y) continue;
        Union(x,y);
        if(edges[i].col==0)
            ++max_cnt;
    }
    Initialize();
    sort(edges+1,edges+m+1,Compare2);
    for(i=1;i<=m;i++)
    {
        int x=Find(edges[i].x);
        int y=Find(edges[i].y);
        if(x==y) continue;
        Union(x,y);
        re+=edges[i].z;
        if(edges[i].col==0)
            ++min_cnt;
    }
    for(i=1;i<=m;i++)
        if(edges[i].col==0)
            edges[i].z-=x;
    if(need<min_cnt)
        return 1;
    if(need>max_cnt)
        return 0;
    throw re-need*x;
}
void Bisection()
{
    int l=-101,r=101;
    while(r-l>1)
    {
        int mid=l+r>>1;
        if( Check(mid) )
            l=mid;
        else
            r=mid;
    }
    Check(l);
    Check(r);
}
int main()
{
    int i;
    cin>>n>>m>>need;
    for(i=1;i<=m;i++)
        cin>>edges[i];
    try
    {
        Bisection();
    }
    catch(int ans)
    {
        cout<<ans<<endl;
        return 0;
    }
    printf("%d\n",1/0);
    return 0;
}

你可能感兴趣的:(kruskal,bzoj,二分答案,BZOJ2654)