有n个点,第i个点的权值为a[i],在第i个点和第j个点之间连边的代价为a[i] and a[j]。问这个图的最大生成树。
n<=100000,a[i]<218 n <= 100000 , a [ i ] < 2 18
这题我们可以用最小生成树的Boruvka算法。具体来说就是一开始对于每个点,找到和该点相连的边权最大的边,然后把这条边加入生成树中。然后对于每个连通块,又找到连接该连通块的边权最大的边加入生成树。这样做每次连通块的复杂度至少减少一半,所以最多做不超过logn次。
现在的问题在于我们如何对于每个点i,找到一个j使得a[i] and a[j]最大。
这个我们可以用字典树。一开始先把每个树都插进去。然后查找的时候,若当前位是1,则走1的子树;若当前位是0,则发现不管走1还是走0都是一样的。于是我们可以在把数插入完后,从下到上把每个节点的1子树复制一遍扔到0子树里面,这样的话就可以每次只走0子树了。
这样做的话不难发现最坏情况每个点都会被祖先遍历一次,于是复杂度是 O(m2m) O ( m 2 m ) 。
我们可以在字典树每个节点维护其子树中连通块编号的最大值和最小值,这样就能判断子树中是否存在不同的连通块。
那么总的时间复杂度就是 O(nmlogn) O ( n m l o g n ) 。
#include
#include
#include
#include
#include
#define mp(x,y) make_pair(x,y)
#define MAX(x,y) x=max(x,y)
#define MIN(x,y) x=min(x,y)
using namespace std;
typedef long long LL;
typedef pair<int,int> pi;
const int N=100005;
int n,a[N],sz,f[N],rt,m,bin[20];
LL ans;
pi mx[N];
struct tree{int l,r,mn,mx;}t[N*10];
int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
int find(int x)
{
if (f[x]==x) return x;
else return f[x]=find(f[x]);
}
int newnode()
{
int x=++sz;t[x].l=t[x].r=t[x].mx=0;t[x].mn=n;return x;
}
void ins(int &d,int dep,int v,int id)
{
if (!d) d=newnode();
MIN(t[d].mn,id);MAX(t[d].mx,id);
if (dep<0) return;
if (v&bin[dep]) ins(t[d].r,dep-1,v,id);
else ins(t[d].l,dep-1,v,id);
}
pi query(int d,int dep,int v,int id)
{
if (dep<0) return mp(0,id==t[d].mn?t[d].mx:t[d].mn);
pi ans;
if (v&bin[dep])
if (t[d].r&&(id!=t[t[d].r].mn||id!=t[t[d].r].mx)) ans=query(t[d].r,dep-1,v,id),ans.first+=bin[dep];
else ans=query(t[d].l,dep-1,v,id);
else ans=query(t[d].l,dep-1,v,id);
return ans;
}
int merge(int x,int y)
{
if (!y) return x;
if (!x) x=newnode();
MIN(t[x].mn,t[y].mn);
MAX(t[x].mx,t[y].mx);
t[x].l=merge(t[x].l,t[y].l);
t[x].r=merge(t[x].r,t[y].r);
return x;
}
void dfs(int d,int dep)
{
if (dep<0) return;
if (t[d].l) dfs(t[d].l,dep-1);
if (t[d].r) dfs(t[d].r,dep-1);
t[d].l=merge(t[d].l,t[d].r);
}
void build()
{
sz=0;rt=newnode();
for (int i=1;i<=n;i++) ins(rt,m-1,a[i],find(i));
dfs(rt,m-1);
}
int main()
{
n=read();m=read();int now=n;
bin[0]=1;
for (int i=1;i<=m;i++) bin[i]=bin[i-1]*2;
for (int i=1;i<=n;i++) a[i]=read(),f[i]=i;
while (now>1)
{
build();
for (int i=1;i<=n;i++) mx[i]=mp(-1,0);
for (int i=1;i<=n;i++)
{
pi u=query(rt,m-1,a[i],find(i));
if (u.first>mx[find(i)].first) mx[find(i)]=u;
}
for (int i=1;i<=n;i++)
if (f[i]==i)
{
int x=find(mx[i].second);
if (x!=i) f[i]=x,now--,ans+=mx[i].first;
}
}
printf("%lld",ans);
return 0;
}