[HDU 4747 Mex] Mex函数 线段树

·题目

http://acm.hdu.edu.cn/showproblem.php?pid=4747

·分析

先预处理出前缀mex,即mex[1][i];再记录每个位置上的数下一次出现的位置next[i](之后不再出现则记为n+1)。由于长度不超过n的区间的mex不会超过n,大于n的数对答案没有影响,都可以看作n+1

然后从前往后枚举左端点i。对于固定的i,容易发现mex[i][j]随j单调不减(区间变长,mex不会变小)。线段树中第j个位置维护的就是当前这一轮的mex[i][j]

于是在区间[i, next[i]-1]中找到第一个满足mex[i][j]>a[i]的位置x(利用单调性,可以在线段树上按区间最大值二分查找)

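那么mex[i+1][x]到mex[i+1][next[i]-1]之间的值一定都变为a[i]。原因是左端点右移到i+1后,对于j∈[i+1, next[i]-1],区间a[i+1..j]中不再含有a[i],所以mex[i+1][j]=min(mex[i][j], a[i])。其中j<x时mex[i][j]<a[i],值不变;j≥next[i]时a[i]仍然出现,值也不变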

于是问题转化为线段树上的区间求和与区间赋值:每一轮先把[i, n]上的mex[i][j]求和并累加进答案,再做上述区间赋值,得到下一轮的mex[i+1][·]

·代码

/**************************************************
 *        Problem:  HDU 4747
 *         Author:  clavichord93
 *          State:  Accepted
 **************************************************/

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
using namespace std;
typedef long long ll;

// Array capacity; presumably sized for n <= 200000 from the problem limits — TODO confirm.
const int MAX_N = 200005;

// Length of the current test case (an input value of 0 terminates the loop in main).
int n;
// Input sequence, 1-indexed; main() clamps values greater than n to n + 1.
int a[MAX_N];
// mex[i] = mex of the prefix a[1..i]; these are the initial leaf values of the tree.
int mex[MAX_N];
// Scratch while building next[] (scanning right to left): latest-seen position of each value, n + 1 if unseen.
int last[MAX_N];
// next[i] = smallest j > i with a[j] == a[i], or n + 1 if there is none.
// NOTE(review): under `using namespace std;` an unqualified `next` is ambiguous with
// std::next in C++11 and later; refer to this array as `::next`.
int next[MAX_N];
// Segment-tree node arrays (heap layout, root at index 1):
// sum  - sum of the values in the node's range,
// maxv - maximum of the values in the node's range,
// tag  - pending range-assignment value, -1 when there is none.
ll sum[MAX_N << 2];
int maxv[MAX_N << 2];
int tag[MAX_N << 2];
// Scratch while computing prefix mex: vis[v] is set once value v has appeared.
bool vis[MAX_N];

// Child indices in the heap-style segment-tree layout (root = 1).
// The argument is parenthesized so that expressions such as lch(a | b) expand correctly.
#define lch(t) ((t) << 1)
#define rch(t) ((t) << 1 | 1)

void makeTree(int t, int l, int r) {
    if (l == r) {
        sum[t] = mex[l];
        maxv[t] = mex[l];
        tag[t] = -1;
    }
    else {
        int mid = (l + r) >> 1;
        makeTree(lch(t), l, mid);
        makeTree(rch(t), mid + 1, r);
        sum[t] = sum[lch(t)] + sum[rch(t)];
        maxv[t] = max(maxv[lch(t)], maxv[rch(t)]);
        tag[t] = -1;
    }
}

void pushdown(int t, int l, int r) {
    if (tag[t] != -1) {
        int lt = lch(t);
        int rt = rch(t);
        int mid = (l + r) >> 1;

        sum[lt] = tag[t] * (mid - l + 1);
        maxv[lt] = tag[t];
        tag[lt] = tag[t];

        sum[rt] = tag[t] * (r - mid);
        maxv[rt] = tag[t];
        tag[rt] = tag[t];

        tag[t] = -1;
    }
}

// Returns the sum of the stored values at positions [x, y] that lie inside
// node t's range [l, r]. A partially covered node pushes its pending tag down
// before the query descends into its children.
ll getSum(int t, int l, int r, int x, int y) {
    if (l >= x && r <= y) {
        return sum[t];
    }

    pushdown(t, l, r);
    const int mid = (l + r) >> 1;
    ll leftPart = 0;
    ll rightPart = 0;
    if (x <= mid) {
        leftPart = getSum(lch(t), l, mid, x, y);
    }
    if (mid < y) {
        rightPart = getSum(rch(t), mid + 1, r, x, y);
    }
    return leftPart + rightPart;
}

int getPos(int t, int l, int r, int x, int y, int val) {
    if (l == r) {
        return l;
    }
    else {
        pushdown(t, l, r);
        int mid = (l + r) >> 1;
        if (x <= mid && maxv[lch(t)] > val) {
            return getPos(lch(t), l, mid, x, y, val);
        }
        if (y > mid && maxv[rch(t)] > val) {
            return getPos(rch(t), mid +1, r, x, y, val);
        }
        return -1;
    }
}

void change(int t, int l, int r, int x, int y, int val) {
    if (x <= l && r <= y) {
        tag[t] = val;
        sum[t] = (ll)(r - l + 1) * val;
        maxv[t] = val;
    }
    else {
        pushdown(t, l, r);
        int mid = (l + r) >> 1;
        if (x <= mid) {
            change(lch(t), l, mid, x, y, val);
        }
        if (y > mid) {
            change(rch(t), mid + 1, r, x, y, val);
        }
        sum[t] = sum[lch(t)] + sum[rch(t)];
        maxv[t] = max(maxv[lch(t)], maxv[rch(t)]);
    }
}

// The child-index helpers are only used by the segment-tree routines above;
// main() does not need them.
#undef lch
#undef rch

int main() {
    #ifdef LOCAL_JUDGE
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    #endif
    while (scanf("%d", &n), n) {
        for (int i = 1; i <= n; i++) {
            scanf("%d", &a[i]);
            if (a[i] > n) {
                a[i] = n + 1;
            }
        }

        for (int i = 0; i <= n + 1; i++) {
            last[i] = n + 1;
            vis[i] = 0;
        }
        for (int i = n; i >= 1; i--) {
            next[i] = last[a[i]];
            last[a[i]] = i;
        }

        int last = 0;
        for (int i = 1; i <= n; i++) {
            vis[a[i]] = 1;
            while (vis[last]) {
                last++;
            }
            mex[i] = last;
        }

        //for (int i = 1; i <= n; i++) {
            //printf("%d ", mex[i]);
        //}
        //printf("\n");

        makeTree(1, 1, n);
        ll ans = 0;
        for (int i = 1; i <= n; i++) {
            ans += getSum(1, 1, n, i, n);
            int x = getPos(1, 1, n, i, next[i] - 1, a[i]);
            //cout << "Ans = " << ans << endl;
            //cout << "Pos = " << x << endl;
            //cout << "Next = " << next[i] << endl;
            //cout << "A[i] = " << a[i] << endl;
            //cout << endl;
            if (1 <= x && x <= next[i] - 1) {
                change(1, 1, n, x, next[i] - 1, a[i]);
            }
        }
        printf("%I64d\n", ans);
    }

    return 0;
}



你可能感兴趣的:(线段树,Mex函数)