[bzoj5294][Bjoi2018]二进制【线段树】

【题目链接】
  https://www.lydsy.com/JudgeOnline/problem.php?id=5294
【题解】
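  考虑一个二进制数什么时候无论怎样重排都不能被3整除,只有两种情况:
  1.只存在一个1。
  2.存在奇数个1且0的数量不大于1。
  (从低位起,1放在偶数位贡献1,放在奇数位贡献2≡-1(mod 3),所以要使两类位置上1的个数之差是3的倍数。1的个数为偶数时总能做到;为奇数时差至少为3,需要至少3个1和至少2个0。两种情况有重叠,如 1、10、01,计数时要去重。)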
  那么我们考虑从总方案中减去不合法的方案。线段树每个节点维护从左端数起第一个、第二个0的位置和第一个、第二个1的位置,以及从右端数起的同样四个位置;合并时只需对跨过中点的区间分情况讨论即可。
  时间复杂度 $O(N\log N)$(每次修改、询问均为 $O(\log N)$)。
【代码】

#include <algorithm>
#include <cstdio>
#include <utility>
# define    ll      long long
# define    inf     0x3f3f3f3f
# define    N       100010
using namespace std;
// Reads the next decimal integer from stdin using getchar().
// Any non-digit characters before the number are skipped; if a '-' is among them
// the result is negated. Returns 0 when input ends before any digit is found.
int read(){
    int tmp = 0, fh = 1;
    int ch = getchar();  // int, not char: getchar() returns EOF as an int outside char's range
    while ((ch < '0' || ch > '9') && ch != EOF){  // stop at EOF instead of spinning forever
        if (ch == '-') fh = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9'){
        tmp = tmp * 10 + (ch - '0');
        ch = getchar();
    }
    return tmp * fh;
}
// One segment-tree node. Positions are 1-based indices into h[]. A missing
// position is stored as inf in the left-side fields and as 0 in the right-side fields.
struct Tree{
    int pl, pr;     // indices of the left / right child inside T[]
    int l00, l01;   // first and second 0 counted from the left end
    int r00, r01;   // first and second 0 counted from the right end
    int l10, l11;   // first and second 1 counted from the left end
    int r10, r11;   // first and second 1 counted from the right end
    ll sum;         // count of subintervals that main() subtracts from the total
};
Tree T[N * 3];      // node pool; nodes are handed out by incrementing `place`

int L, R;           // index range covered by the root
int place;          // number of nodes allocated so far
int n;              // length of the bit string
int q[6];           // small scratch array
int h[N];           // input bits, 1-based
int rt;             // index of the root node
int m;              // number of operations
// Strict ascending-order comparator for sorting ints.
bool cmp0(int x, int y){ return !(x >= y); }
// Strict descending-order comparator for sorting ints.
bool cmp1(int x, int y){ return !(x <= y); }
void reget(int p, int mid, int ql, int qr){
    ll sum = T[T[p].pl].sum + T[T[p].pr].sum;
    ll l0 = mid - max(T[T[p].pl].r00 + 1, ql) + 1, r0 = min(T[T[p].pr].l00 - 1, qr) - (mid + 1) + 1, l1, r1;
    sum += (l0 / 2) * ((r0 + 1) / 2) + ((l0 + 1) / 2) * (r0 / 2);
    l0 = mid - max(T[T[p].pl].r00 + 1, ql) + 1, r0 = min(T[T[p].pr].l00 - 1, qr) - (mid + 1) + 1, 
    l1 = mid - max(T[T[p].pl].r01 + 1, ql) + 1, r1 = min(T[T[p].pr].l01 - 1, qr) - (mid + 1) + 1;
    ll nowl = l1 - l0, nowr = r0;
    if (l0 == 0)
        sum += max(((nowl + 1) / 2) * ((nowr + 1) / 2) + (nowl / 2) * (nowr / 2) - 1, 0ll);
        else if (l0 % 2 == 0)
            sum += ((nowl + 1) / 2) * ((nowr + 1) / 2) + (nowl / 2) * (nowr / 2);
            else sum += (nowl / 2) * ((nowr + 1) / 2) + ((nowl + 1) / 2) * (nowr / 2);
    nowl = l0, nowr = r1 - r0;
    if (r0 == 0)
        sum += max(((nowl + 1) / 2) * ((nowr + 1) / 2) + (nowl / 2) * (nowr / 2) - 1, 0ll);
        else if (r0 % 2 == 0)
            sum += ((nowl + 1) / 2) * ((nowr + 1) / 2) + (nowl / 2) * (nowr / 2);
            else sum += (nowl / 2) * ((nowr + 1) / 2) + ((nowl + 1) / 2) * (nowr / 2);
    l0 = mid - max(ql - 1, T[T[p].pl].r10) + 1, r0 = min(T[T[p].pr].l10, qr + 1) - (mid + 1) + 1;
    l1 = mid - max(ql - 1, T[T[p].pl].r11) + 1, r1 = min(T[T[p].pr].l11, qr + 1) - (mid + 1) + 1;
    sum += (l1 - l0) * (r0 - 1) + (r1 - r0) * (l0 - 1);
    T[p].sum = sum;
    q[1] = T[T[p].pl].l00, q[2] = T[T[p].pl].l01,  q[3] = T[T[p].pr].l00, q[4] = T[T[p].pr].l01;
    sort(q + 1, q + 4 + 1, cmp0);
    T[p].l00 = q[1], T[p].l01 = q[2];
    q[1] = T[T[p].pl].r00, q[2] = T[T[p].pl].r01,  q[3] = T[T[p].pr].r00, q[4] = T[T[p].pr].r01;
    sort(q + 1, q + 4 + 1, cmp1);
    T[p].r00 = q[1], T[p].r01 = q[2];
    q[1] = T[T[p].pl].l10, q[2] = T[T[p].pl].l11,  q[3] = T[T[p].pr].l10, q[4] = T[T[p].pr].l11;
    sort(q + 1, q + 4 + 1, cmp0);
    T[p].l10 = q[1], T[p].l11 = q[2];
    q[1] = T[T[p].pl].r10, q[2] = T[T[p].pl].r11,  q[3] = T[T[p].pr].r10, q[4] = T[T[p].pr].r11;
    sort(q + 1, q + 4 + 1, cmp1);
    T[p].r10 = q[1], T[p].r11 = q[2];
}
void build(int &p, int l, int r){
    p = ++place;
    if (l != r){
        int mid = (l + r) / 2;
        build(T[p].pl, l, mid);
        build(T[p].pr, mid + 1, r);
        reget(p, mid, l, r);
    }
    else{
        if (h[l] == 1){
            T[p].l10 = T[p].r10 = l;
            T[p].l11 = inf, T[p].r11 = 0; 
            T[p].l00 = inf, T[p].r00 = 0;
            T[p].l01 = inf, T[p].r01 = 0;
            T[p].sum = 1;
        }
        else {
            T[p].l00 = T[p].r00 = l; 
            T[p].l11 = inf, T[p].r11 = 0; 
            T[p].l10 = inf, T[p].r10 = 0;
            T[p].l01 = inf, T[p].r01 = 0;
            T[p].sum = 0;
        }
    }
}
ll query(int p, int ql, int qr, int l, int r){
    if (ql == l && qr == r) return T[p].sum;
    int mid = (l + r) / 2;
    if (mid >= qr) return query(T[p].pl, ql, qr, l, mid);
        else if (mid < ql) return query(T[p].pr, ql, qr, mid + 1, r);
            else {
                ll sum = query(T[p].pl, ql, mid, l, mid) + query(T[p].pr, mid + 1, qr, mid + 1, r);
                ll l0 = mid - max(T[T[p].pl].r00 + 1, ql) + 1, r0 = min(T[T[p].pr].l00 - 1, qr) - (mid + 1) + 1, l1, r1;
                sum += (l0 / 2) * ((r0 + 1) / 2) + ((l0 + 1) / 2) * (r0 / 2);
                l0 = mid - max(T[T[p].pl].r00 + 1, ql) + 1, r0 = min(T[T[p].pr].l00 - 1, qr) - (mid + 1) + 1, 
                l1 = mid - max(T[T[p].pl].r01 + 1, ql) + 1, r1 = min(T[T[p].pr].l01 - 1, qr) - (mid + 1) + 1;
                ll nowl = l1 - l0, nowr = r0;
                if (l0 == 0)
                    sum += max(((nowl + 1) / 2) * ((nowr + 1) / 2) + (nowl / 2) * (nowr / 2) - 1, 0ll);
                    else if (l0 % 2 == 0)
                        sum += ((nowl + 1) / 2) * ((nowr + 1) / 2) + (nowl / 2) * (nowr / 2);
                        else sum += (nowl / 2) * ((nowr + 1) / 2) + ((nowl + 1) / 2) * (nowr / 2);
                nowl = l0, nowr = r1 - r0;
                if (r0 == 0)
                    sum += max(((nowl + 1) / 2) * ((nowr + 1) / 2) + (nowl / 2) * (nowr / 2) - 1, 0ll);
                    else if (r0 % 2 == 0)
                        sum += ((nowl + 1) / 2) * ((nowr + 1) / 2) + (nowl / 2) * (nowr / 2);
                        else sum += (nowl / 2) * ((nowr + 1) / 2) + ((nowl + 1) / 2) * (nowr / 2);
                l0 = mid - max(ql - 1, T[T[p].pl].r10) + 1, r0 = min(T[T[p].pr].l10, qr + 1) - (mid + 1) + 1;
                l1 = mid - max(ql - 1, T[T[p].pl].r11) + 1, r1 = min(T[T[p].pr].l11, qr + 1) - (mid + 1) + 1;
                sum += (l1 - l0) * (r0 - 1) + (r1 - r0) * (l0 - 1);
                return sum;
            }
}
void modify(int p, int x, int l, int r){
    if (l == r){
        swap(T[p].l00, T[p].l10);
        swap(T[p].r00, T[p].r10);
        T[p].sum = 1 - T[p].sum;
        return;
    }
    int mid = (l + r) / 2;
    if (mid >= x) modify(T[p].pl, x, l, mid);
        else modify(T[p].pr, x, mid + 1, r);
    reget(p, mid, l, r);
}
int main(){
    n = read();
    for (int i = 1; i <= n; i++)
        h[i] = read();
    L = 1, R = n;
    build(rt, 1, n);
    m = read();
    for (int i = 1; i <= m; i++){
        int op = read();
        if (op == 2){
            int l = read(), r = read();
            ll ans = 1ll * (r - l + 2) * (r - l + 1) / 2; 
            ans -= query(rt, l, r, L, R);
            printf("%lld\n", ans);
        }
        else {
            int x = read();
            modify(rt, x, L, R);
        }
    }
    return 0;
}

你可能感兴趣的:(【线段树】)