2017年ICPC中国大陆区域赛 - Sum of xor sum (线段树维护子段)

题意:

链接:https://vjudge.net/contest/319951#problem/T

给你 n 个数,和 q 次询问,每次询问给你 [ l , r ] ,让输出在区间 [ l , r ] 里每一个子区间亦或和加和。比如: 给你三个数 1、2、3、和一个询问 [ 1 , 3 ] , 需要输出 1 + 2 + 3 + 1 ^ 2 + 2 ^ 3 + 1 ^ 2 ^ 3 。(1 ≤ N,Q ≤ 100000 , 1 ≤ i ≤ N, 0 ≤ A[i] ≤ 1000000 , 1 ≤ L ≤ R ≤ N)

解题思路:

这个题目可以转化为求每一个二进制位的贡献,因为一些数亦或的结果是看对应的第 i 个二进制位的 1 的个数,如果是偶数个那么 第 i 个二进制位亦或出来必然是 0 ,反之为 1 。所以可以用线段树来维护区间 [ l ,r ] 内对于第 i 个二进制位来说 有多少个子区间 1 的个数为奇数个。但是想维护这一个变量还需要其他的辅助变量来维护,还需要维护左连续区间(满足1的个数为偶数个)的个数、左连续区间(满足1的个数为奇数个)的个数、右连续区间(满足1的个数为偶数个)的个数、右连续区间(满足1的个数为奇数个)的个数、当前区间的1的个数是奇数个还是偶数个。代码里注释解释了为什么需要维护这些。又因为 a[i] 是小于 1e6 的,所以21位二进制位就够了,所以我计算的 0 - 20 位二进制位,每一位的贡献,具体细节可以参考代码。

这个题目的线段树的维护方式和线段树维护最大子段和的维护方式很类似,如果做过线段树维护最大子段和的题目,这个题目也就很简单了。相关线段树维护最大子段和的链接:https://blog.csdn.net/ltrbless/article/details/99540854

AC代码:

#include
#define up(i, x, y) for(int i = x; i <= y; i++)
#define down(i, x, y) for(int i = x; i >= y; i--)
#define bug printf("***************************\n")
#define debug(x) cout<<#x"=["< vec;

void push_up(int k)
{
    for(int i = 0; i <= 20; i++)
    {
        t[k].jo[i] = t[lk].jo[i] + t[rk].jo[i]; //维护当前区间的1的个数是奇数个还是偶数个
        t[k].sum[i] = t[lk].sum[i] + t[rk].sum[i] + t[lk].ro[i] * t[rk].lj[i] + t[lk].rj[i] * t[rk].lo[i] ;
        // 当前区间的子区间(满足1的个数为奇数个)的区间个数 = 左孩子的个数+右孩子的个数
        // +左孩子右连续区间(满足1的个数为偶数个)的个数✖右孩子左连续区间(满足1的个数为奇数个)的个数
        // +左孩子右连续区间(满足1的个数为奇数个)的个数✖右孩子左连续区间(满足1的个数为偶数个)的个数
        t[k].lo[i] = t[lk].lo[i] + (t[lk].jo[i] & 1) ? t[rk].lj[i] : t[rk].lo[i];
        // 左连续区间(满足1的个数为偶数个)的个数 = 左孩子左连续区间(满足1的个数为偶数个)的个数 + 
        // 如果左边的1的个数是否为奇数 ? 如果是加上右孩子左连续区间(满足1的个数为奇数个)的个数 :
        // 否则加上右孩子左连续区间(满足1的个数为偶数个)的个数
        t[k].ro[i] = t[rk].ro[i] + (t[rk].jo[i] & 1) ? t[lk].rj[i] : t[lk].ro[i]; // 同上
        t[k].lj[i] = t[lk].lj[i] + (t[lk].jo[i] & 1) ? t[rk].lo[i] : t[rk].lj[i]; // 同上
        t[k].rj[i] = t[rk].rj[i] + (t[rk].jo[i] & 1) ? t[lk].ro[i] : t[lk].rj[i]; // 同上
    }
}

void build(int k, int l, int r)
{
    if(l == r)  // 赋初值,建树
    {
        int tmp = a[l];
        for(int i = 0; i <= 20; i++)
        {
            if((tmp >> i) & 1) 
            {
                t[k].sum[i] = 1;
                t[k].lj[i] = 1;
                t[k].rj[i] = 1;
            }
            else
            {
                t[k].sum[i] = 0;
                t[k].lo[i] = 1;
                t[k].ro[i] = 1;
            }

            t[k].jo[i] = t[k].sum[i];
        }
        return ;
    }
    int mid = (l + r) >> 1;
    build(lk, l, mid);
    build(rk, mid + 1, r);
    push_up(k);
}

void query(int k, int l, int r, int ql, int qr)
{
    if(ql <= l && r <= qr)
    {
        vec.push_back(k); // 找出哪些点需要合并
        return ;
    }
    int mid = (l + r) >> 1;
    if(ql <= mid) query(lk, l, mid, ql, qr);
    if(mid + 1 <= qr) query(rk, mid + 1, r, ql, qr);
}

void join(int x, int y) // 都以 t[0] 为基准,合并到t[0]里
{
    for(int i = 0; i <= 20; i++)
    {
        t[x].jo[i] = t[x].jo[i] + t[y].jo[i];
        t[x].sum[i] = t[x].sum[i] + t[y].sum[i] + t[x].ro[i] * t[y].lj[i] + t[x].rj[i] * t[y].lo[i] ;
        t[x].lo[i] = t[x].lo[i] + (t[x].jo[i] & 1) ? t[y].lj[i] : t[y].lo[i];
        t[x].ro[i] = t[y].ro[i] + (t[y].jo[i] & 1) ? t[x].rj[i] : t[x].ro[i];
        t[x].lj[i] = t[x].lj[i] + (t[x].jo[i] & 1) ? t[y].lo[i] : t[y].lj[i];
        t[x].rj[i] = t[y].rj[i] + (t[y].jo[i] & 1) ? t[x].ro[i] : t[x].rj[i];
    }
}

void init() // 初始化
{
    memset(t[0].jo, 0, sizeof(t[0].jo));
    memset(t[0].lo, 0, sizeof(t[0].lo));
    memset(t[0].ro, 0, sizeof(t[0].ro));
    memset(t[0].lj, 0, sizeof(t[0].lj));
    memset(t[0].rj, 0, sizeof(t[0].rj));
    memset(t[0].sum, 0, sizeof(t[0].sum));
}

int main()
{
    int T; scanf("%d", &T); while(T--)
    {
        scanf("%d %d", &n, &q);
        for(int i = 1; i <= n; i++)
        {
            scanf("%d", &a[i]);
        }
        build(1, 1, n);
        while(q--)
        {
            scanf("%d %d", &l, &r);
            ll ans = 0; vec.clear(); int v;
            init();
            query(1, 1, n, l, r); // 查找线段树中哪些点需要合并
            v = vec[0]; // 第一个点取出来(也是最左边的点)
            if(vec.size() == 1)  // 都以 t[0] 为基准,合并到t[0]里
            {
                t[0] = t[v];
            }
            else
            {                
                t[0] = t[v];
                for(int i = 1; i < vec.size(); i++)
                {
                    join(0, vec[i]); // 都以 t[0] 为基准,合并到t[0]里
                }
            }
            for(int i = 0; i <= 20; i++)
            {
                ans = (ans + (t[0].sum[i] * 1LL * (1 << i)) % mod) % mod; // 根据每一位的贡献,算出总的贡献大小
            }
            printf("%lld\n", ans);

        }
    }
}

 

 

 

 

 

你可能感兴趣的:(ACM,数据结构,ICPC中国大陆区域赛)