【题目】
原题地址
题目大意:给定一幅无向图, 每条边有一个边权,一个人走过边时要交的税就是这条边边权。每个点有 pi p i 个人,他们要去到1号点搞活动。现在你拥有其中 k k 条边的控制权,即你可以任意改变这些边边权。在你决定边权后,你可以决定这幅图的一棵最小生成树,每个点上的人只能走这些边到1号点。求你最多能得到多少的税款。
【题目分析】
又是一道很巧妙的题目,关键点在于看清楚题目。
一眼看过去很不可做,不过k很小,我们可以考虑像动态维护MST的时候一样,将图缩小。
【解题思路】
首先注意到每条边的边权不同,因此在没有新边时,MST是固定的。
又发现新边很少,回想起动态维护MST时我们的做法,我们可以考虑缩图。
这样将新边设为 −INF − I N F ,我们做MST后,在MST上的原图边一定会出现在我们最后的选择中。
接着将所有新边去掉进行缩点,图变成了至多 k+1 k + 1 个联通块,同时我们可以缩边,这样最多有400条边。
用这些边再做一次MST,就能将边数缩成 k k 。
接下来我们要做的就是确定新边的边权,但是发现我们不一定选新边建MST会最优,所以我们还要枚举一个边集。
考虑枚举边集后我们结合原图中的边作MST,那么如果我们新加进一条边时,如果形成了一个环,则这条边会对环上所有未确定边权的边产生一个限制。
这样如果我们暴力更新环上所有边的话的话,后半部分的复杂度是 O(2k∗k2) O ( 2 k ∗ k 2 ) 的,算起来过不了,实际上能过。
但实际上这个做法还可以得到优化,我们发现要优化关键在于减少边权的更新。
我们可以对所有原图边按边权从小到大进行编号,这样如果我们知道一条新边会被哪些原图边算入环,我们就可以O1得到限制(lowbit一下)。
然后对于每条原图不在MST上的边 (u,v) ( u , v ) ,我们在 u u 和 v v 上都打一个编号标记,
对于每个节点,我们知道它有一些标记,那么它的父边就会被这些标记代表的边贡献到。
回溯时,我们将这个节点的标记异或上所有它儿子节点的标记就可以了(可以画图验证)
这样子后半部分的复杂度就是 O(2k∗k2) O ( 2 k ∗ k 2 ) ,可以完美通过此题。
对了前面缩图,MST之类的复杂度应该就是 O(mlogm) O ( m l o g m ) 之类的吧,不会超过后面的复杂度。
具体细节可以看代码。
【一些奇奇怪怪的神秘错误】
我是和机房的sc大佬一起写这道题目的。但是我的程序可能因为某种不知名的程序规范,开启O2后答案会错误,但是不开O2是正确的。
sc大佬的程序十分正常,但是过不了BZOJ有边权相同的数据。
这时候问题来了,我的程序把BZOJ的数据过了。。。然而BZOJ应该是自带O2的,这就十分玄学了。
交到UOJ上我是过不了的,sc是AC的。
当然洛谷上数据貌似有错误,因此我们都过不了。
下面的代码仅供参考借鉴实现,暴力计算边权的程序是从网上扒下来的。
【参考代码】
我的代码:(BZOJ-AC)
#include
using namespace std;
typedef long long LL;
const int INF=1e9;
const int P=25;
const int N=1e5+10;
const int M=6e5+10;
int n,m,tot,K,cnt,ps,ns,rway;
int tar[P],fa[N],head[N],bl[N];
int mp[P][P],bo[P][P],Log2[4194306];
vector<int>es[P];
LL tmpans,ans;
LL peo[P],p[N];
struct Tway
{
int v,nex,w;
};
Tway e[M];
struct Tnode
{
int x,y,w,op;
};
Tnode a[M],b[P],c[P],rem[P];
void add(int u,int v,int w)
{
e[++tot]=(Tway){v,head[u],w};head[u]=tot;
e[++tot]=(Tway){u,head[v],w};head[v]=tot;
}
bool cmp(Tnode A,Tnode B)
{
return A.wbool cmpop(Tnode A,Tnode B)
{
return A.op>B.op;
}
int read()
{
int ret=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-')f=0;c=getchar();}
while(isdigit(c)){ret=(ret<<1)+(ret<<3)+(c^48);c=getchar();}
return f?ret:-ret;
}
int findf(int x)
{
return fa[x]==x?x:fa[x]=findf(fa[x]);
}
void init()
{
Log2[1]=0;
for(int i=2;i<4194305;++i)
Log2[i]=Log2[i>>1]+1;
n=read();m=read();K=read();
for(int i=1;i<=m;++i)
a[i].x=read(),a[i].y=read(),a[i].w=read();
for(int i=1;i<=K;++i)
b[i].x=read(),b[i].y=read();
for(int i=1;i<=n;++i)
p[i]=read();
for(int i=1;i<=n;++i)
fa[i]=i;
sort(a+1,a+m+1,cmp);
for(int i=1;i<=K;++i)
{
int fx=findf(b[i].x),fy=findf(b[i].y);
if(fx!=fy)
fa[fx]=fy;
}
for(int i=1;i<=m;++i)
{
int fx=findf(a[i].x),fy=findf(a[i].y);
if(fx!=fy)
fa[fx]=fy,a[i].op=1;
}
}
void dfs(int x)
{
peo[cnt]+=p[x];bl[x]=cnt;
for(int i=head[x];i;i=e[i].nex)
{
int v=e[i].v;
if(!bl[v])
dfs(v);
}
}
void getmp()
{
sort(a+1,a+m+1,cmpop);
for(int i=1;i<=m && a[i].op;++i)
add(a[i].x,a[i].y,0);
for(int i=1;i<=n;++i)
if(!bl[i])
++cnt,dfs(i);
for(int i=1;i<=cnt;++i)
for(int j=1;j<=cnt;++j)
mp[i][j]=mp[j][i]=INF;
for(int i=1;i<=m;++i)
{
int x=bl[a[i].x],y=bl[a[i].y];
if(x==y)
continue;
mp[x][y]=mp[y][x]=min(mp[x][y],a[i].w);
}
tot=0;
for(int i=0;i0;
for(int i=1;ifor(int j=i+1;j<=cnt;++j)
if(mp[i][j]!=INF)
++ps,c[ps].x=i,c[ps].y=j,c[ps].w=mp[i][j];
for(int i=0;i1,c+ps+1,cmp);
for(int i=1;i<=ps;++i)
{
int fx=findf(c[i].x),fy=findf(c[i].y);
if(fx!=fy)
c[++ns]=c[i],fa[fx]=fy;
}
}
int lowbit(int x)
{
return x&(-x);
}
LL dfs2(int x,int f)
{
LL sum=peo[x];
for(int i=0;i<(int)es[x].size();++i)
{
int v=es[x][i];
if(v!=f)
sum+=dfs2(v,x);
}
if(x!=bl[1])
{
tar[f]^=tar[x];
if(bo[x][f])
{
int p=Log2[lowbit(tar[x])];
tmpans+=1ll*sum*c[p+1].w;
}
}
return sum;
}
void solve()
{
for(int i=1;i<=K;++i)
b[i].x=bl[b[i].x],b[i].y=bl[b[i].y];
for(int sta=1;sta<(1<for(int i=0;i<=cnt+1;++i)
fa[i]=i,es[i].clear(),tar[i]=0;
rway=0;
bool flag=true;
for(int i=1;i<=K;++i)
{
if(!(sta&(1<<(i-1))))
continue;
int fx=findf(b[i].x),fy=findf(b[i].y);
bo[b[i].x][b[i].y]=bo[b[i].y][b[i].x]=i;rem[++rway]=b[i];
if(fx!=fy)
fa[fx]=fy,es[b[i].x].push_back(b[i].y),es[b[i].y].push_back(b[i].x);
else
flag=false;
}
if(!flag)
{
for(int i=1;i<=rway;++i)
bo[rem[i].x][rem[i].y]=bo[rem[i].y][rem[i].x]=0;
continue;
}
for(int i=1;i<=ns;++i)
{
int fx=findf(c[i].x),fy=findf(c[i].y);
if(fx!=fy)
{
fa[fx]=fy;
es[c[i].x].push_back(c[i].y),es[c[i].y].push_back(c[i].x);
}
else
tar[c[i].x]|=(1<<(i-1)),tar[c[i].y]|=(1<<(i-1));
}
tmpans=0;
dfs2(bl[1],bl[1]);
ans=max(ans,tmpans);
for(int i=1;i<=rway;++i)
bo[rem[i].x][rem[i].y]=bo[rem[i].y][rem[i].x]=0;
}
printf("%lld\n",ans);
}
int main()
{
freopen("BZOJ3206.in","r",stdin);
freopen("BZOJ3206.out","w",stdout);
init();
getmp();
solve();
return 0;
}
sc的代码:(UOJ-AC)
#include
using namespace std;
typedef long long ll;
const int maxn = 6e+5 + 10;
const int maxk = 20 + 5;
const int INF = 0x3fffffff;
int n,m,k,ta[maxn];ll a[maxk];
int lab[maxn],labn;
int llg2[maxn];
struct edge{
int u,v,d;
}e[maxn],ek[maxk];
struct DSU{
int f[maxn];
void init(int x){memset(f,0,x<<2);}
int fa(int x)
{
if(!f[x]) return x;
return f[x]=fa(f[x]);
}
bool merge(int x,int y)
{
x=fa(x);y=fa(y);
if(x==y) return false;
if(xreturn true;
}
void relable()
{
int i;
labn=1;memset(a,0,sizeof(a));
for(i=1;i<=n;i++) if(!f[i]) lab[i]=labn++;
for(i=1;i<=n;i++) lab[i]=lab[fa(i)],a[lab[i]]+=ta[i];//,printf("%d in lab:%d\n",i,lab[i]);
}
}L,K;
inline int read()
{
int x=0;char c=getchar();
for(;c<'0'||c>'9';c=getchar());
for(;c>='0'&&c<='9';c=getchar()) x=(x<<3)+(x<<1)+(c^48);
return x;
}
bool cmpd(edge a,edge b){return a.dvoid input()
{
int i;
n=read();m=read();k=read();
for(i=0;ifor(i=m;i1;
}
for(i=1;i<=n;i++) ta[i]=read();
for(i=0;i<20;i++) llg2[1<void build()
{
int i,j;
sort(e,e+m+k,cmpd);
for(i=0;iif(K.merge(e[i].u,e[i].v)) if(~e[i].d)
{
L.merge(e[i].u,e[i].v);
e[i].d=INF;
}
if(!~e[i].d) e[i].d=INF;
}
L.relable();
sort(e,e+m+k,cmpd);K.init(n+1);
j=0;
for(i=0;i//printf("visit: %d-%d d=%d\n",e[i].u,e[i].v,e[i].d);
if(K.merge(lab[e[i].u],lab[e[i].v])) e[j++]=e[i];//,printf("select: %d-%d d=%d\n",e[i].u,e[i].v,e[i].d);
}
m=j;n=labn;
}
int tag[maxk];ll ans,siz[maxk];
vectorint ,bool> > nxt[maxk];
inline void addedge(int u,int v,bool type)
{//printf("addedge %d-%d tp=%d\n",u,v,type);
nxt[u].push_back(make_pair(v,type));
nxt[v].push_back(make_pair(u,type));
}
void dfs(int x,int f)
{
int i,v,j,d;
siz[x]=a[x];
for(i=0;iif((v=nxt[x][i].first)!=f)
{
dfs(v,x);tag[x]^=tag[v];siz[x]+=siz[v];
if(nxt[x][i].second)
{
j=llg2[tag[v]&-tag[v]];//cerr<
d=e[j].d;
ans+=siz[v]*d;
}
}
}
void solve()
{
int S,i;ll fans=0;
//printf("n=%d\n",n);
//for(i=0;i
//for(i=0;i
for(S=0;S<(1<memset (tag,0,sizeof(tag));
for(i=1;ifor (i=0;iif (S>>i&1)
{
if(!K.merge(lab[ek[i].u],lab[ek[i].v])) goto Fail;
addedge(lab[ek[i].u],lab[ek[i].v],true);
}
for(i=0;iif(!K.merge(lab[e[i].u],lab[e[i].v])) tag[lab[e[i].u]]|=1<1<else addedge(lab[e[i].u],lab[e[i].v],false);
}
ans=0;
dfs(1,-1);//printf("S=%d ans=%lld\n",S,ans);
fans=max(fans,ans);
Fail:;
}
cout<int main()
{
freopen("lg3639.in","r",stdin);
freopen("lg3639.out","w",stdout);
input();
build();
solve();
return 0;
}
网上的暴力修改代码(BZOJ-AC、UOJ-AC):
#include
using namespace std;
typedef long long LL;
const int INF=1e9;
const int N=1e5+10;
const int M=3e5+10;
const int P=25;
int n,m,K,top,cnt,st;
int fa[N],fa2[N],p[N];
int po[P],dep[N],last[N],mn[N];
LL ans,val[N],sum[N];
bool mark[M];
LL read()
{
LL x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
struct Tnode
{
int u,v,w;
}e[M],ne[P],q[M];
bool operator<(Tnode a,Tnode b)
{
return a.wint to,next;
}ed[P<<1];
void insert(int u,int v)
{
ed[++cnt]=(Tway){v,last[u]};last[u]=cnt;
ed[++cnt]=(Tway){u,last[v]};last[v]=cnt;
}
int find(int x)
{
return x==fa[x]?x:fa[x]=find(fa[x]);
}
int find2(int x)
{
return x==fa2[x]?x:fa2[x]=find2(fa2[x]);
}
void dp(int x)
{
sum[x]=val[x];
for(int i=last[x];i;i=ed[i].next)
if(ed[i].to!=fa2[x])
{
dep[ed[i].to]=dep[x]+1;
fa2[ed[i].to]=x;
dp(ed[i].to);
sum[x]+=sum[ed[i].to];
}
}
void solve()
{
cnt=0;
for(int i=1;i<=K+1;i++)
{
int p=po[i];
last[p]=fa2[p]=0;
fa[p]=p;mn[p]=INF;
}
for(int i=1;i<=K;i++)
if(mark[i])
{
int x=find(ne[i].u),y=find(ne[i].v);
if(x==y)return;
fa[x]=y;
insert(ne[i].u,ne[i].v);
}
for(int i=1;i<=K;i++)
{
int x=find(q[i].u),y=find(q[i].v);
if(x!=y)fa[x]=y,insert(q[i].u,q[i].v);
}
dp(st);
for(int i=1;i<=K;i++)
{
int u=q[i].u,v=q[i].v;
if(dep[u]>dep[v])swap(u,v);
while(dep[v]!=dep[u])mn[v]=min(mn[v],q[i].w),v=fa2[v];
while(u!=v)
{
mn[v]=min(mn[v],q[i].w);
mn[u]=min(mn[u],q[i].w);
u=fa2[u];v=fa2[v];
}
}
LL inc=0;
for(int i=1;i<=K;i++)
if(mark[i])
{
int u=ne[i].u,v=ne[i].v;
if(dep[u]>dep[v])swap(u,v);
inc+=mn[v]*sum[v];
}
ans=max(inc,ans);
}
void dfs(int x)
{
if(x==K+1)
{
solve();
return;
}
mark[x]=0;dfs(x+1);
mark[x]=1;dfs(x+1);
}
void init()
{
n=read();m=read();K=read();
for(int i=1;i<=m;i++)
e[i].u=read(),e[i].v=read(),e[i].w=read();
sort(e+1,e+m+1);
for(int i=1;i<=K;i++)
ne[i].u=read(),ne[i].v=read();
for(int i=1;i<=n;i++)p[i]=read();
for(int i=1;i<=n;i++)fa[i]=fa2[i]=i;
for(int i=1;i<=K;i++)
fa[find(ne[i].u)]=find(ne[i].v);
for(int i=1;i<=m;i++)
{
int u=e[i].u,v=e[i].v;
if(find(u)!=find(v))
{
fa[find(u)]=fa[find(v)];
fa2[find2(u)]=fa2[find2(v)];
}
}
}
void getmp()
{
st=find2(1);
for(int i=1;i<=n;i++)
{
val[find2(i)]+=p[i];
if(find2(i)==i)po[++po[0]]=i;
}
for(int i=1;i<=K;i++)
ne[i].u=find2(ne[i].u),ne[i].v=find2(ne[i].v);
for(int i=1;i<=m;i++)
e[i].u=find2(e[i].u),e[i].v=find2(e[i].v);
for(int i=1;i<=m;i++)
{
int p=find2(e[i].u),q=find2(e[i].v);
if(p!=q)mark[i]=1,fa2[p]=q;
}
for(int i=1;i<=m;i++)
if(mark[i])q[++top]=e[i];
}
int main()
{
freopen("BZOJ3206.in","r",stdin);
freopen("BZOJ3206.out","w",stdout);
init();
getmp();
dfs(1);
printf("%lld\n",ans);
return 0;
}
【总结】
其实缩图也是一种很好的思想qwq。
然后数据太坑了!O2太坑了!