hdu 4747 Mex(线段树区间更新+二分)

题目:

给出一个序列,mex{}表示集合中没有出现的最小的自然数。
然后求 mex(i,j)

解析:

思路转载自 cxlove

考虑左端点固定时的所有区间的mex值,这个序列是一个非递减序列,这点首先要明白。

初始时,先求出mex[j]表示mex(1, j)。(可以用map求出)
对于每一个左端点i,就是一个区间求和。(可以利用线段树维护)

现在需要考虑的是左端点的改变对于序列的影响。

即左端点i,从 i -> i + 1,mex[j]的改变……,即删去 ai 对于序列的影响。
如果 a[j]=a[i]j>ia[k]=a[i](j>k>i) ,那么 j 即 a[i] 下一次出现的位置。(也可利用map,求出 j 的位置)

根据mex的定义,我们知道 mex[k](k>=j) 不会改变,因为删掉的 ai 还是存在于序列当中,所以不受影响。

之后需要考虑的是 i+1 j1 这段区间的mex{}值。
删去了 ai 之后,使得原先mex{}值大于 ai 的,都会更新成 ai
很好理解。因为是没有出现的最小的,然而 ai 更小。

之前说过这是一个非递减的序列,所以原先mex值大于 ai 的也是一段连续的区间,所以我们可以找到最靠左的位置r,使得 a[i] < mex[r]。(二分查找最靠左的位置)
那么 r 到 j-1 这段区间的mex值,便会更新为a[i]。

所以全部搞定。用线段树维护一下mex序列,区间更新,区间求和,然后一个查找就可以了。

my code

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
#include <set>
#define ls (o<<1)
#define rs (o<<1|1)
#define lson ls, L, M
#define rson rs, M+1, R
#define MID (L + R) >> 1
#define LEN(L, R) ((R) - (L) + 1)
using namespace std;
typedef long long ll;
const int INF = 0x3f3f3f3f;
const int N = 200005;
ll a[N];
ll sumv[N<<2], cov[N<<2];
int n, mex[N], jump[N];
map<ll, int> mp;

inline void pushDown(int o, int L, int R) {
    if(cov[o] != -1) {
        int M = MID;
        cov[ls] = cov[rs] = cov[o];
        sumv[ls] = LEN(L, M) * cov[o];
        sumv[rs] = LEN(M+1, R) * cov[o];
        cov[o] = -1;
    }
}

inline void pushUp(int o) {
    sumv[o] = sumv[ls] + sumv[rs];
}

void build(int o, int L, int R) {
    cov[o] = -1;
    sumv[o] = 0;
    if(L == R) {
        cov[o] = sumv[o] = mex[L];
        return ;
    }
    int M = MID;
    build(lson);
    build(rson);
    pushUp(o);
}

void modify(int o, int L, int R, int ql, int qr, ll val) {
    if(ql <= L && R <= qr) {
        cov[o] = val;
        sumv[o] = LEN(L, R) * val;
        return ;
    }
    pushDown(o, L, R);
    int M = MID;
    if(ql <= M) modify(lson, ql, qr, val);
    if(qr > M) modify(rson, ql, qr, val);
    pushUp(o);
}

ll query(int o, int L, int R, int ql, int qr) {
    if(ql <= L && R <= qr) return sumv[o];
    pushDown(o, L, R);
    int M = MID;
    ll ret = 0;
    if(ql <= M) ret += query(lson, ql, qr);
    if(qr > M) ret += query(rson, ql, qr);
    return ret;
}

ll get(int o, int L, int R, int pos) {
    if(L == R) return sumv[o];
    pushDown(o, L, R);
    int M = MID;
    if(pos <= M) return get(lson, pos);
    else return get(rson, pos);
}

void getMex() {
    mp.clear();
    int tmp = 0;
    for(int i = 1; i <= n; i++) {
        mp[a[i]] = 1;
        while(mp.find(tmp) != mp.end())
            tmp++;
        mex[i] = tmp;
    }

    mp.clear();
    for(int i = n; i >= 1; i--) {
        if(mp.find(a[i]) == mp.end())
            jump[i] = n+1;
        else jump[i] = mp[a[i]];
        mp[a[i]] = i;
    }
}

int search(int start, int end, int lim) {
    int L = start, R = end+1;
    while(L < R) {
        int M = MID;
        ll tmp = get(1, 1, n, M);
        if(tmp > lim) R = M;
        else L = M + 1;
    }
    return L;
}

ll cal() {
    int ql, qr;
    ll ret = query(1, 1, n, 1, n);
    for(int i = 2; i <= n; i++) {
        qr = jump[i-1] - 1;
        ql = search(i, qr, a[i-1]);
        if(ql <= qr)
            modify(1, 1, n, ql, qr, a[i-1]);
        ret += query(1, 1, n, i, n);
    }
    return ret;
}

int main() {
    while(~scanf("%d", &n) && n) {
        for(int i = 1; i <= n; i++) {
            scanf("%lld", &a[i]);
        }
        getMex();
        build(1, 1, n);    
        printf("%lld\n", cal());
    }
    return 0;
}

你可能感兴趣的:(HDU,4747)