【MST-缩图】BZOJ3206 [APIO2013]道路费用

【题目】
原题地址
题目大意:给定一幅无向图, 每条边有一个边权,一个人走过边时要交的税就是这条边边权。每个点有 pi p i 个人,他们要去到1号点搞活动。现在你拥有其中 k k 条边的控制权,即你可以任意改变这些边边权。在你决定边权后,你可以决定这幅图的一棵最小生成树,每个点上的人只能走这些边到1号点。求你最多能得到多少的税款。

【题目分析】
又是一道很巧妙的题目,关键点在于看清楚题目。
一眼看过去很不可做,不过k很小,我们可以考虑像动态维护MST的时候一样,将图缩小。

【解题思路】
首先注意到每条边的边权不同,因此在没有新边时,MST是固定的。
又发现新边很少,回想起动态维护MST时我们的做法,我们可以考虑缩图。
这样将新边设为 INF − I N F ,我们做MST后,在MST上的原图边一定会出现在我们最后的选择中。
接着将所有新边去掉进行缩点,图变成了至多 k+1 k + 1 个联通块,同时我们可以缩边,这样最多有400条边。
用这些边再做一次MST,就能将边数缩成 k k

接下来我们要做的就是确定新边的边权,但是发现我们不一定选新边建MST会最优,所以我们还要枚举一个边集。
考虑枚举边集后我们结合原图中的边作MST,那么如果我们新加进一条边时,如果形成了一个环,则这条边会对环上所有未确定边权的边产生一个限制。
这样如果我们暴力更新环上所有边的话的话,后半部分的复杂度是 O(2kk2) O ( 2 k ∗ k 2 ) 的,算起来过不了,实际上能过。

但实际上这个做法还可以得到优化,我们发现要优化关键在于减少边权的更新。
我们可以对所有原图边按边权从小到大进行编号,这样如果我们知道一条新边会被哪些原图边算入环,我们就可以O1得到限制(lowbit一下)。
然后对于每条原图不在MST上的边 (u,v) ( u , v ) ,我们在 u u v v 上都打一个编号标记,
对于每个节点,我们知道它有一些标记,那么它的父边就会被这些标记代表的边贡献到。
回溯时,我们将这个节点的标记异或上所有它儿子节点的标记就可以了(可以画图验证)
这样子后半部分的复杂度就是 O(2kk2) O ( 2 k ∗ k 2 ) ,可以完美通过此题。

对了前面缩图,MST之类的复杂度应该就是 O(mlogm) O ( m l o g m ) 之类的吧,不会超过后面的复杂度。

具体细节可以看代码。

【一些奇奇怪怪的神秘错误】
我是和机房的sc大佬一起写这道题目的。但是我的程序可能因为某种不知名的程序规范,开启O2后答案会错误,但是不开O2是正确的。
sc大佬的程序十分正常,但是过不了BZOJ有边权相同的数据。
这时候问题来了,我的程序把BZOJ的数据过了。。。然而BZOJ应该是自带O2的,这就十分玄学了。
交到UOJ上我是过不了的,sc是AC的。
当然洛谷上数据貌似有错误,因此我们都过不了。

下面的代码仅供参考借鉴实现,暴力计算边权的程序是从网上扒下来的。

【参考代码】

我的代码:(BZOJ-AC)

#include
using namespace std;

typedef long long LL;
const int INF=1e9;
const int P=25;
const int N=1e5+10;
const int M=6e5+10;
int n,m,tot,K,cnt,ps,ns,rway;
int tar[P],fa[N],head[N],bl[N];
int mp[P][P],bo[P][P],Log2[4194306];
vector<int>es[P];
LL tmpans,ans;
LL peo[P],p[N];

struct Tway
{
    int v,nex,w;
};
Tway e[M];

struct Tnode
{
    int x,y,w,op;
};
Tnode a[M],b[P],c[P],rem[P];

void add(int u,int v,int w)
{
    e[++tot]=(Tway){v,head[u],w};head[u]=tot;
    e[++tot]=(Tway){u,head[v],w};head[v]=tot;
}

bool cmp(Tnode A,Tnode B)
{
    return A.wbool cmpop(Tnode A,Tnode B)
{
    return A.op>B.op;
}

int read()
{
    int ret=0,f=1;char c=getchar();
    while(!isdigit(c)){if(c=='-')f=0;c=getchar();}
    while(isdigit(c)){ret=(ret<<1)+(ret<<3)+(c^48);c=getchar();}
    return f?ret:-ret;
}

int findf(int x)
{
    return fa[x]==x?x:fa[x]=findf(fa[x]);
}

void init()
{
    Log2[1]=0;
    for(int i=2;i<4194305;++i)
        Log2[i]=Log2[i>>1]+1;
    n=read();m=read();K=read();
    for(int i=1;i<=m;++i)
        a[i].x=read(),a[i].y=read(),a[i].w=read();
    for(int i=1;i<=K;++i)
        b[i].x=read(),b[i].y=read();
    for(int i=1;i<=n;++i)
        p[i]=read();

    for(int i=1;i<=n;++i)
        fa[i]=i;
    sort(a+1,a+m+1,cmp);
    for(int i=1;i<=K;++i)
    {
        int fx=findf(b[i].x),fy=findf(b[i].y);
        if(fx!=fy)
            fa[fx]=fy;
    }

    for(int i=1;i<=m;++i)
    {
        int fx=findf(a[i].x),fy=findf(a[i].y);
        if(fx!=fy)
            fa[fx]=fy,a[i].op=1;
    }
}

void dfs(int x)
{
    peo[cnt]+=p[x];bl[x]=cnt;
    for(int i=head[x];i;i=e[i].nex)
    {
        int v=e[i].v;
        if(!bl[v])
            dfs(v);
    }
}

void getmp()
{
    sort(a+1,a+m+1,cmpop);
    for(int i=1;i<=m && a[i].op;++i)
        add(a[i].x,a[i].y,0);
    for(int i=1;i<=n;++i)
        if(!bl[i])
            ++cnt,dfs(i);

    for(int i=1;i<=cnt;++i)
        for(int j=1;j<=cnt;++j)
            mp[i][j]=mp[j][i]=INF;
    for(int i=1;i<=m;++i)
    {
        int x=bl[a[i].x],y=bl[a[i].y];
        if(x==y)    
            continue;
        mp[x][y]=mp[y][x]=min(mp[x][y],a[i].w);
    }

    tot=0;
    for(int i=0;i0;
    for(int i=1;ifor(int j=i+1;j<=cnt;++j)
            if(mp[i][j]!=INF)
                ++ps,c[ps].x=i,c[ps].y=j,c[ps].w=mp[i][j];

    for(int i=0;i1,c+ps+1,cmp);
    for(int i=1;i<=ps;++i)
    {
        int fx=findf(c[i].x),fy=findf(c[i].y);
        if(fx!=fy)
            c[++ns]=c[i],fa[fx]=fy;
    }
}

int lowbit(int x)
{
    return x&(-x);
}

LL dfs2(int x,int f)
{
    LL sum=peo[x];
    for(int i=0;i<(int)es[x].size();++i)
    {
        int v=es[x][i];
        if(v!=f)
            sum+=dfs2(v,x); 
    }
    if(x!=bl[1])
    {
        tar[f]^=tar[x];
        if(bo[x][f])
        {
            int p=Log2[lowbit(tar[x])];
            tmpans+=1ll*sum*c[p+1].w;
        }
    }
    return sum;
}

void solve()
{
    for(int i=1;i<=K;++i)
        b[i].x=bl[b[i].x],b[i].y=bl[b[i].y];
    for(int sta=1;sta<(1<for(int i=0;i<=cnt+1;++i)   
            fa[i]=i,es[i].clear(),tar[i]=0;
        rway=0;
        bool flag=true;
        for(int i=1;i<=K;++i)
        {
            if(!(sta&(1<<(i-1))))
                continue;
            int fx=findf(b[i].x),fy=findf(b[i].y);
            bo[b[i].x][b[i].y]=bo[b[i].y][b[i].x]=i;rem[++rway]=b[i];
            if(fx!=fy)
                fa[fx]=fy,es[b[i].x].push_back(b[i].y),es[b[i].y].push_back(b[i].x);
            else
                flag=false;
        }
        if(!flag)
        {
            for(int i=1;i<=rway;++i)
                bo[rem[i].x][rem[i].y]=bo[rem[i].y][rem[i].x]=0;
            continue;
        }
        for(int i=1;i<=ns;++i)
        {
            int fx=findf(c[i].x),fy=findf(c[i].y);
            if(fx!=fy)
            {
                fa[fx]=fy;
                es[c[i].x].push_back(c[i].y),es[c[i].y].push_back(c[i].x);
            }
            else
                tar[c[i].x]|=(1<<(i-1)),tar[c[i].y]|=(1<<(i-1));
        }
        tmpans=0;
        dfs2(bl[1],bl[1]);
        ans=max(ans,tmpans);
        for(int i=1;i<=rway;++i)
            bo[rem[i].x][rem[i].y]=bo[rem[i].y][rem[i].x]=0;
    }
    printf("%lld\n",ans);
}

int main()
{
    freopen("BZOJ3206.in","r",stdin);
    freopen("BZOJ3206.out","w",stdout);

    init();
    getmp();
    solve();

    return 0;
}

sc的代码:(UOJ-AC)

#include
using namespace std;
typedef long long ll;

const int maxn = 6e+5 + 10;
const int maxk = 20 + 5;
const int INF = 0x3fffffff;

int n,m,k,ta[maxn];ll a[maxk];
int lab[maxn],labn;
int llg2[maxn];

struct edge{
    int u,v,d;
}e[maxn],ek[maxk];

struct DSU{
    int f[maxn];

    void init(int x){memset(f,0,x<<2);}

    int fa(int x)
    {
        if(!f[x]) return x;
        return f[x]=fa(f[x]);
    }

    bool merge(int x,int y)
    {
        x=fa(x);y=fa(y);
        if(x==y) return false;
        if(xreturn true;
    }

    void relable()
    {
        int i;
        labn=1;memset(a,0,sizeof(a));
        for(i=1;i<=n;i++) if(!f[i]) lab[i]=labn++;
        for(i=1;i<=n;i++) lab[i]=lab[fa(i)],a[lab[i]]+=ta[i];//,printf("%d in lab:%d\n",i,lab[i]);
    }
}L,K;

inline int read()
{
    int x=0;char c=getchar();
    for(;c<'0'||c>'9';c=getchar());
    for(;c>='0'&&c<='9';c=getchar()) x=(x<<3)+(x<<1)+(c^48);
    return x;
}

bool cmpd(edge a,edge b){return a.dvoid input()
{
    int i;

    n=read();m=read();k=read();
    for(i=0;ifor(i=m;i1;
    }
    for(i=1;i<=n;i++) ta[i]=read();
    for(i=0;i<20;i++) llg2[1<void build()
{
    int i,j;

    sort(e,e+m+k,cmpd);
    for(i=0;iif(K.merge(e[i].u,e[i].v)) if(~e[i].d)
        {
            L.merge(e[i].u,e[i].v);
            e[i].d=INF;
        }
        if(!~e[i].d) e[i].d=INF;
    }
    L.relable();

    sort(e,e+m+k,cmpd);K.init(n+1);

    j=0;
    for(i=0;i//printf("visit: %d-%d d=%d\n",e[i].u,e[i].v,e[i].d);
        if(K.merge(lab[e[i].u],lab[e[i].v])) e[j++]=e[i];//,printf("select: %d-%d d=%d\n",e[i].u,e[i].v,e[i].d);
    }
    m=j;n=labn;
}

int tag[maxk];ll ans,siz[maxk];
vectorint,bool> > nxt[maxk];
inline void addedge(int u,int v,bool type)
{//printf("addedge %d-%d tp=%d\n",u,v,type);
    nxt[u].push_back(make_pair(v,type));
    nxt[v].push_back(make_pair(u,type));
}

void dfs(int x,int f)
{
    int i,v,j,d;

    siz[x]=a[x];
    for(i=0;iif((v=nxt[x][i].first)!=f)
    {
        dfs(v,x);tag[x]^=tag[v];siz[x]+=siz[v];
        if(nxt[x][i].second)
        {
            j=llg2[tag[v]&-tag[v]];//cerr<
            d=e[j].d;
            ans+=siz[v]*d;
        }
    }
}

void solve()
{
    int S,i;ll fans=0;

    //printf("n=%d\n",n);
    //for(i=0;i
    //for(i=0;i

    for(S=0;S<(1<memset(tag,0,sizeof(tag));
        for(i=1;ifor(i=0;iif(S>>i&1)
        {
            if(!K.merge(lab[ek[i].u],lab[ek[i].v])) goto Fail;
            addedge(lab[ek[i].u],lab[ek[i].v],true);
        }
        for(i=0;iif(!K.merge(lab[e[i].u],lab[e[i].v])) tag[lab[e[i].u]]|=1<1<else addedge(lab[e[i].u],lab[e[i].v],false);
        }

        ans=0;
        dfs(1,-1);//printf("S=%d ans=%lld\n",S,ans);
        fans=max(fans,ans);

        Fail:;
    }
    cout<int main()
{
    freopen("lg3639.in","r",stdin);
    freopen("lg3639.out","w",stdout);

    input();
    build();
    solve();

    return 0;
}

网上的暴力修改代码(BZOJ-AC、UOJ-AC):

#include
using namespace std;

typedef long long LL;
const int INF=1e9;
const int N=1e5+10;
const int M=3e5+10;
const int P=25;

int n,m,K,top,cnt,st;
int fa[N],fa2[N],p[N];
int po[P],dep[N],last[N],mn[N];
LL ans,val[N],sum[N];
bool mark[M];

LL read()
{
    LL x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

struct Tnode
{
    int u,v,w;
}e[M],ne[P],q[M];

bool operator<(Tnode a,Tnode b)
{
    return a.wint to,next;
}ed[P<<1];


void insert(int u,int v)
{
    ed[++cnt]=(Tway){v,last[u]};last[u]=cnt;
    ed[++cnt]=(Tway){u,last[v]};last[v]=cnt;
}
int find(int x)
{
    return x==fa[x]?x:fa[x]=find(fa[x]);
}

int find2(int x)
{
    return x==fa2[x]?x:fa2[x]=find2(fa2[x]);
}

void dp(int x)
{
    sum[x]=val[x];
    for(int i=last[x];i;i=ed[i].next)
        if(ed[i].to!=fa2[x])
        {
            dep[ed[i].to]=dep[x]+1;
            fa2[ed[i].to]=x;
            dp(ed[i].to);
            sum[x]+=sum[ed[i].to];
        }
}

void solve()
{
    cnt=0;
    for(int i=1;i<=K+1;i++)
    {
        int p=po[i];
        last[p]=fa2[p]=0;
        fa[p]=p;mn[p]=INF;
    }
    for(int i=1;i<=K;i++)
        if(mark[i])
        {
            int x=find(ne[i].u),y=find(ne[i].v);
            if(x==y)return;
            fa[x]=y;
            insert(ne[i].u,ne[i].v);
        }
    for(int i=1;i<=K;i++)
    {
        int x=find(q[i].u),y=find(q[i].v);
        if(x!=y)fa[x]=y,insert(q[i].u,q[i].v);
    }

    dp(st);
    for(int i=1;i<=K;i++)
    {
        int u=q[i].u,v=q[i].v;
        if(dep[u]>dep[v])swap(u,v);
        while(dep[v]!=dep[u])mn[v]=min(mn[v],q[i].w),v=fa2[v];
        while(u!=v)
        {
            mn[v]=min(mn[v],q[i].w);
            mn[u]=min(mn[u],q[i].w);
            u=fa2[u];v=fa2[v];
        }
    }

    LL inc=0;
    for(int i=1;i<=K;i++)
        if(mark[i])
        {
            int u=ne[i].u,v=ne[i].v;
            if(dep[u]>dep[v])swap(u,v);
            inc+=mn[v]*sum[v];
        }
    ans=max(inc,ans);
}

void dfs(int x)
{
    if(x==K+1)
    {
        solve();
        return;
    }
    mark[x]=0;dfs(x+1);
    mark[x]=1;dfs(x+1);
}

void init()
{
    n=read();m=read();K=read();
    for(int i=1;i<=m;i++)
        e[i].u=read(),e[i].v=read(),e[i].w=read();
    sort(e+1,e+m+1);
    for(int i=1;i<=K;i++)
        ne[i].u=read(),ne[i].v=read();
    for(int i=1;i<=n;i++)p[i]=read();
    for(int i=1;i<=n;i++)fa[i]=fa2[i]=i;
    for(int i=1;i<=K;i++)
        fa[find(ne[i].u)]=find(ne[i].v);
    for(int i=1;i<=m;i++)
    {
        int u=e[i].u,v=e[i].v;
        if(find(u)!=find(v))
        {
            fa[find(u)]=fa[find(v)];
            fa2[find2(u)]=fa2[find2(v)];
        }
    }
}

void getmp()
{
    st=find2(1);
    for(int i=1;i<=n;i++)
    {
        val[find2(i)]+=p[i];
        if(find2(i)==i)po[++po[0]]=i;
    }
    for(int i=1;i<=K;i++)
        ne[i].u=find2(ne[i].u),ne[i].v=find2(ne[i].v);
    for(int i=1;i<=m;i++)
        e[i].u=find2(e[i].u),e[i].v=find2(e[i].v);
    for(int i=1;i<=m;i++)
    {
        int p=find2(e[i].u),q=find2(e[i].v);
        if(p!=q)mark[i]=1,fa2[p]=q;
    }
    for(int i=1;i<=m;i++)
        if(mark[i])q[++top]=e[i];
}

int main()
{
    freopen("BZOJ3206.in","r",stdin);
    freopen("BZOJ3206.out","w",stdout);

    init();
    getmp();
    dfs(1);
    printf("%lld\n",ans);

    return 0;
}

【总结】
其实缩图也是一种很好的思想qwq。
然后数据太坑了!O2太坑了!

你可能感兴趣的:(图论-MST)