01字典树专题

以前一直以为字典树没有多少用,但是最近一直碰到(难道是以前刷题太少的原因么),其中有一类问题叫做01字典树问题,它是用来解决xor的有力武器,通常是给你一个数组,问你一段连续的异或和最大是多少,正常思路贪心dp啥的都会一头雾水,但是用01字典树就能很快的解决,实现起来也十分方便。
贴一个01字典树的普遍模版

int ch[32*MAX][2];
LL val[32*MAX];
int sz;

void init(){
    mem(ch[0],0);
    sz=1;
}

void inser(LL a){
    int u=0;
    for(int i=32;i>=0;i--){
        int c=((a>>i)&1);
        if(!ch[u][c]){
            mem(ch[sz],0);
            val[sz]=0;
            ch[u][c]=sz++;
        }
        u=ch[u][c];
    }
    val[u]=a;
}
LL query(LL a){
    int u=0;
    for(int i=32;i>=0;i--){
        int c=((a>>i)&1);
        if(ch[u][c^1]) u=ch[u][c^1];
        else u=ch[u][c];
    }
    return val[u];
}

中间的细节可以自己修改,比如有时可能会删除某个数,就需要记录这个节点走了多少次,如果次数为0,就不往下走,数组大小应该开32(64,如果是LL)*数组元素个数。
废话不多说,来看题把。

HDU 4825
题目传送门:http://acm.hdu.edu.cn/showproblem.php?pid=4825
01字典树入门题,每个数都插入字典树中,然后查询即可AC。

HDU 5536
题目传送门:http://acm.hdu.edu.cn/showproblem.php?pid=5536
带有删除的01字典树,先把每个元素插入字典树中,然后 O(n2) 的复杂度枚举两个相加,并且把它们从字典树中暂时去掉,然后查询,然后取最大值
代码:

#include <map>
#include <set>
#include <stack>
#include <queue>
#include <cmath>
#include <string>
#include <vector>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <sstream>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#pragma comment(linker,"/STACK:102400000,102400000")

using namespace std;
#define MAX 1005
#define MAXN 6005
#define maxnode 15
#define sigma_size 30
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define lrt rt<<1
#define rrt rt<<1|1
#define middle int m=(r+l)>>1
#define LL long long
#define ull unsigned long long
#define mem(x,v) memset(x,v,sizeof(x))
#define lowbit(x) (x&-x)
#define pii pair<int,int>
#define bits(a) __builtin_popcount(a)
#define mk make_pair
#define limit 10000

//const int prime = 999983;
const int    INF   = 0x3f3f3f3f;
const LL     INFF  = 0x3f3f;
const double pi    = acos(-1.0);
const double inf   = 1e18;
const double eps   = 1e-8;
const LL    mod    = 1e9+7;
const ull    mx    = 133333331;

/*****************************************************/
inline void RI(int &x) {
      char c;
      while((c=getchar())<'0' || c>'9');
      x=c-'0';
      while((c=getchar())>='0' && c<='9') x=(x<<3)+(x<<1)+c-'0';
 }
/*****************************************************/

int ch[32*MAX][2];
LL val[32*MAX];
int num[32*MAX];
int sz;
LL b[MAX];

void init(){
    mem(ch[0],0);
    sz=1;
}

void inser(LL a){
    int u=0;
    for(int i=32;i>=0;i--){
        int c=((a>>i)&1);
        if(!ch[u][c]){
            mem(ch[sz],0);
            val[sz]=0;
            num[sz]=0;
            ch[u][c]=sz++;
        }
        u=ch[u][c];
        num[u]++;
    }
    val[u]=a;
}
void update(LL a,int d){
    int u=0;
    for(int i=32;i>=0;i--){
        int c=((a>>i)&1);
        u=ch[u][c];
        num[u]+=d;
    }
}
LL query(LL a){
    int u=0;
    for(int i=32;i>=0;i--){
        int c=((a>>i)&1);
        if(ch[u][c^1]&&num[ch[u][c^1]]) u=ch[u][c^1];
        else u=ch[u][c];
    }
    return a^val[u];
}

int main(){
    int t,kase=0;
    cin>>t;
    while(t--){
        int n;
        scanf("%d",&n);
        init();
        for(int i=1;i<=n;i++){
            scanf("%I64d",&b[i]);
            inser(b[i]);
        }
        //kase++;
        //printf("Case #%d:\n",kase);
        LL maxn=0;
        for(int i=1;i<=n;i++){
            for(int j=1;j<=n;j++){
                if(i==j) continue;
                update(b[i],-1);update(b[j],-1);
                maxn=max(maxn,query(b[i]+b[j]));
                update(b[i],1);update(b[j],1);
            }
        }
        cout<<maxn<<endl;
    }
    return 0;
}

bzoj 4260
题目传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=4260
让你在长为n的数组里找两个不相交的连续区间,使得两个区间分别的异或和求和之后最大
反正看到连续区间异或和最大,我只会一个套路,就是01字典树了,先正着来一遍前缀异或和,同时 dp[i] 表示到i为止,前面的区间异或和最大是多少,然后倒着来一遍后缀异或和, ans=max(ans,query(suf[i])+dp[i1]) 即可,并且这题不会爆int,其实好像上面的题也不会爆int,但是这题全开LL会MLE
代码:

#include <map>
#include <set>
#include <stack>
#include <queue>
#include <cmath>
#include <string>
#include <vector>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <sstream>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#pragma comment(linker, "/STACK:102400000,102400000")

using namespace std;
#define MAX 400005
#define MAXN 6005
#define maxnode 15
#define sigma_size 30
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define lrt rt<<1
#define rrt rt<<1|1
#define middle int m=(r+l)>>1
#define LL long long
#define ull unsigned long long
#define mem(x,v) memset(x,v,sizeof(x))
#define lowbit(x) (x&-x)
#define pii pair<int,int>
#define bits(a) __builtin_popcount(a)
#define mk make_pair
#define limit 10000

//const int prime = 999983;
const int    INF   = 0x3f3f3f3f;
const LL     INFF  = 0x3f3f;
const double pi    = acos(-1.0);
//const double inf = 1e18;
const double eps   = 1e-8;
const LL    mod    = 1e9+7;
const ull    mx    = 133333331;

/*****************************************************/
inline void RI(int &x) {
      char c;
      while((c=getchar())<'0' || c>'9');
      x=c-'0';
      while((c=getchar())>='0' && c<='9') x=(x<<3)+(x<<1)+c-'0';
 }
/*****************************************************/

int pre[MAX];
int suf[MAX];
int a[MAX];
int dp[MAX];
int ch[32*MAX][2];
int val[32*MAX];
int sz;

void init(){
    sz=1;
    mem(ch[0],0);
}

void inser(int a){
    int u=0;
    for(int i=31;i>=0;i--){
        int c=((a>>i)&1);
        if(!ch[u][c]){
            mem(ch[sz],0);
            val[sz]=0;
            ch[u][c]=sz++;
        }
        u=ch[u][c];
    }
    val[u]=a;
}
int query(int a){
    int u=0;
    for(int i=31;i>=0;i--){
        int c=((a>>i)&1);
        if(ch[u][c^1]) u=ch[u][c^1];
        else u=ch[u][c];
    }
    return val[u]^a;
}
int main(){
    int n;
    cin>>n;
    init();
    for(int i=1;i<=n;i++) scanf("%d",&a[i]);
    pre[0]=suf[n+1]=0;
    for(int i=1;i<=n;i++) pre[i]=pre[i-1]^a[i];
    for(int i=n;i>0;i--) suf[i]=suf[i+1]^a[i];
    mem(dp,0);
    inser(pre[0]);
    for(int i=1;i<=n;i++){
        dp[i]=max(dp[i-1],query(pre[i]));
        inser(pre[i]);
    }
    init();
    int maxn=0;
    inser(suf[n+1]);
    for(int i=n;i>0;i--){
        maxn=max(maxn,query(suf[i])+dp[i-1]);
        inser(suf[i]);
    }
    cout<<maxn<<endl;
    return 0;
}

poj 3764
题目传送门:http://poj.org/problem?id=3764
这题是树上的最大异或和路径,但是其实也是一样的套路,在dfs的时候,把从根到当前节点的异或和,去01字典树里查询,找到一条路径和当前路径异或和最大,这找到的绝对是两条相连的,因为你往01字典树里扔的,就是从根到当前节点的异或和
代码:

#include <map>
#include <set>
#include <stack>
#include <queue>
#include <cmath>
#include <string>
#include <vector>
#include <cstdio>
#include <cctype>
#include <cstring>
#include <sstream>
#include <cstdlib>
#include <iostream>
#include <algorithm>
#pragma comment(linker, "/STACK:102400000,102400000")

using namespace std;
#define MAX 100005
#define MAXN 6005
#define maxnode 15
#define sigma_size 30
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define lrt rt<<1
#define rrt rt<<1|1
#define middle int m=(r+l)>>1
#define LL long long
#define ull unsigned long long
#define mem(x,v) memset(x,v,sizeof(x))
#define lowbit(x) (x&-x)
#define pii pair<int,int>
#define bits(a) __builtin_popcount(a)
#define mk make_pair
#define limit 10000

//const int prime = 999983;
const int    INF   = 0x3f3f3f3f;
const LL     INFF  = 0x3f3f;
const double pi    = acos(-1.0);
//const double inf = 1e18;
const double eps   = 1e-8;
const LL    mod    = 1e9+7;
const ull    mx    = 133333331;

/*****************************************************/
inline void RI(int &x) {
      char c;
      while((c=getchar())<'0' || c>'9');
      x=c-'0';
      while((c=getchar())>='0' && c<='9') x=(x<<3)+(x<<1)+c-'0';
 }
/*****************************************************/

int ch[32*MAX][2];
int val[32*MAX];
int sz;
struct Edge{
    int v,next,c;
}edge[MAX*2];
int head[MAX];
int tot;
int ans;
void init(){
    sz=1;
    mem(ch[0],0);
    mem(head,-1);
    tot=0;
    ans=0;
}
void add_edge(int a,int b,int c){
    edge[tot]=(Edge){b,head[a],c};
    head[a]=tot++;
}
void inser(int a){
    int u=0;
    for(int i=31;i>=0;i--){
        int c=((a>>i)&1);
        if(!ch[u][c]){
            mem(ch[sz],0);
            val[sz]=0;
            ch[u][c]=sz++;
        }
        u=ch[u][c];
    }
    val[u]=a;
}
int query(int a){
    int u=0;
    for(int i=31;i>=0;i--){
        int c=((a>>i)&1);
        if(ch[u][c^1]) u=ch[u][c^1];
        else u=ch[u][c];
    }
    return val[u]^a;
}

void dfs(int u,int fa,int c){
    inser(c);
    for(int i=head[u];i!=-1;i=edge[i].next){
        int v=edge[i].v;
        if(v==fa) continue;
        ans=max(ans,query(c^edge[i].c));
        dfs(v,u,c^edge[i].c);
    }
}
int main(){
    int n;
    while(cin>>n){
        init();
        for(int i=1;i<n;i++){
            int a,b,c;
            scanf("%d%d%d",&a,&b,&c);
            add_edge(a,b,c);
            add_edge(b,a,c);
        }
        dfs(0,-1,0);
        cout<<ans<<endl;
    }
    return 0;
}

总结:
总的来说如果碰到连续异或和,一般都是01字典树,并且01字典树也很容易,顶多就是加个删除,这个用加标记当前节点用过多少次,删除的时候num--;,恢复的时候num++;即可,十分简单。
xor的还有一种套路就是按位枚举的贪心构造,一般就是几个零散的数的xor和最大或者最小,反正不要求连续的,可以往按位枚举考虑,不过按位枚举也更难。

你可能感兴趣的:(01字典树专题)