在图论中,无向图 G G G 的生成树,是具有 G G G 的全部顶点,但边数最少的连通子图。
而 G G G 的最小生成树,即为 G G G 中所有的生成树中,所有边的边权和最小的一棵生成树。
给定一个数组 x 1 , x 2 , ⋯ , x n x_1,x_2,\cdots, x_n x1,x2,⋯,xn。构造一张 n n n 个点的完全图。对于任意 1 ≤ i < j ≤ n 1 \le i < j \le n 1≤i<j≤n,图中有一条边权为 x j − x i x_j - x_i xj−xi 的无向边。
你想要求出这张无向图的最小生成树的边权之和。
求最小生成树的常用算法有 Kruskal,Prim \text{Kruskal,Prim} Kruskal,Prim。 Kruskal \text{Kruskal} Kruskal 的时间复杂度是与边数相关的,而此题边数很多,所以考虑 Prim \text{Prim} Prim 的思想。
维护两个点集:已匹配点集个未匹配点集。每次找出两个点集的最短边,将边的未匹配点集点加入已匹配点集。
对于此题,考虑使用线段树维护两个点集的最短边权值。
m a x 0 r t max0_{rt} max0rt 表示 r t rt rt 所代表的区间中,未匹配的点中 x x x 值最大的。
m a x 1 r t max1_{rt} max1rt 表示 r t rt rt 所代表的区间中,已匹配的点中 x x x 值最大的。
m i n 0 r t min0_{rt} min0rt 表示 r t rt rt 所代表的区间中,未匹配的点中 x x x 值最小的。
m i n 1 r t min1_{rt} min1rt 表示 r t rt rt 所代表的区间中,已匹配的点中 x x x 值最小的。
a n s r t ans_{rt} ansrt 表示 r t rt rt 所代表的区间中的两个点之间的最小边权,其中两个点不能是一个点集。
在 p u s h u p pushup pushup 过程中, m a x 0 , m a x 1 , m i n 0 , m i n 1 max0,max1,min0,min1 max0,max1,min0,min1 只需简单地从儿子区间更新。
a n s ans ans 先从儿子区间更新,再考虑两个分居两个区间的点之间的边权,容易想到要边权最小,一定是选右边最小的和左边最大的(注意两点要不同点集), a n s ans ans 用其更新即可。
每次找到最短边后,将边的未匹配点集点 x x x 加入已匹配点集,实现可以把 x x x 在线段树单点交换 m a x 0 , m a x 1 max0,max1 max0,max1 和 m i n 0 , m i n 1 min0,min1 min0,min1。(相当于将这个点的属性有未匹配变为匹配了)
时间复杂度 O ( n log n ) O(n\log n) O(nlogn)
具体实现参见代码
#include
using namespace std;
const int N=3e5+1,INF=1e9;
int n,a[N];
long long Ans;
struct node
{
int v,id;
node(){}
node(int a,int b){v=a,id=b;}
bool operator<(const node &a)const{
return v<a.v;
}
}max0[N<<2],max1[N<<2],min0[N<<2],min1[N<<2],ans[N<<2];
void pushup(int rt)
{
max0[rt]=max(max0[rt<<1],max0[rt<<1|1]);
max1[rt]=max(max1[rt<<1],max1[rt<<1|1]);
min0[rt]=min(min0[rt<<1],min0[rt<<1|1]);
min1[rt]=min(min1[rt<<1],min1[rt<<1|1]);
ans[rt]=min({ans[rt<<1],ans[rt<<1|1],min(node(min1[rt<<1|1].v-max0[rt<<1].v,max0[rt<<1].id),
node(min0[rt<<1|1].v-max1[rt<<1].v,min0[rt<<1|1].id))});
}
void build(int rt,int l,int r)
{
if(l==r){
max0[rt]=min0[rt]=node(a[l],l);
max1[rt]=node(-INF,-1);
min1[rt]=node(INF,-1);
ans[rt]=node(INF,-1);
return;
}
int mid=l+r>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
}
void update(int rt,int l,int r,int x)
{
if(l==r){
swap(max0[rt],max1[rt]),swap(min0[rt],min1[rt]);
return;
}
int mid=l+r>>1;
if(x<=mid) update(rt<<1,l,mid,x);
else update(rt<<1|1,mid+1,r,x);
pushup(rt);
}
int main()
{
freopen("mst.in","r",stdin);
freopen("mst.out","w",stdout);
cin.tie(0)->sync_with_stdio(0);
cin>>n;
for(int i=1;i<=n;i++) cin>>a[i];
build(1,1,n);
update(1,1,n,1);
for(int i=1;i<n;i++){
Ans+=ans[1].v;
update(1,1,n,ans[1].id);
}
cout<<Ans;
}