题意:有一个序列a[],mex(L, R)表示区间a在区间[L, R]上第一个没出现的最小非负整数,对于序列a[],求所有的mex(L, R)的和(1 <= L <= R <= n,1 <= n <= 200000,0 <= ai <= 10^9)。
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4747
——>>线段树就是如此的神~
求出所有的mex(1, i);接着删去第1个结点,就是所有的mex(2, i);接着再删去第1个结点,就是所有的mex(3, i);……最后就是mex(n, n),求和即是答案。
而维护删除结点后的信息,正是线段树的拿手好戏。
对于每个线段树结点(o, L, R),设mexv[o]表示mex(left, R),这里的left表示第一个数的下标,初始1,随着删除的进行,left递增。
设sumv[o]表示区间[L, R]上的所有mexv的和。
当删除了一个结点a[i]时,如果a[i] < mexv[1],说明a[i]被删后一定会使某个区间的mex变成a[i],这个区间就是第一个mexv比a[i]大的i到下个a[i]出现的前一位。
~~
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; const int N = 200000; const int maxn = 200000 + 10; int n, a[maxn], vis[maxn], nxt[maxn], setv[maxn<<2]; //nxt[i]表示下一个a[i]出现的位置,没有为n+1 long long mex1[maxn], mexv[maxn<<2], sumv[maxn<<2]; //mex1[i]表示mex(1, i) void read(){ for(int i = 1; i <= n; i++){ scanf("%d", &a[i]); if(a[i] >= N) a[i] = N; } } void init(){ memset(vis, 0, sizeof(vis)); } void getMex1(){ //获取mex1[] int ret = 0; //因为mex1[]是递增的,所以ret = 0放在for的外面(放里面就O(n^2)了) for(int i = 1; i <= n; i++){ vis[a[i]] = 1; for(; vis[ret]; ret++); mex1[i] = ret; } } void getNxt(){ //获取nxt[] for(int i = 0; i <= N; i++) vis[i] = n+1; //初始化为n+1 for(int i = n; i >= 1; i--){ //注意:从右往左!!! nxt[i] = vis[a[i]]; vis[a[i]] = i; } } void maintain(int o, int L, int R){ //维护函数 int lc = o << 1, rc = lc | 1; mexv[o] = max(mexv[lc], mexv[rc]); sumv[o] = sumv[lc] + sumv[rc]; } void build(int o, int L, int R){ //建树 setv[o] = -1; //赋值标记 if(L == R){ mexv[o] = sumv[o] = mex1[L]; return; } int M = (L + R) >> 1; build(o<<1, L, M); build(o<<1|1, M+1, R); maintain(o, L, R); } inline void get(int o, int L, int R, int v){ //单点赋值 mexv[o] = v; sumv[o] = v * (R - L + 1); setv[o] = v; } void pushdown(int o, int L, int R){ //下传机制 if(setv[o] != -1){ int M = (L + R) >> 1; int lc = o << 1, rc = lc | 1; get(lc, L, M, setv[o]); get(rc, M+1, R, setv[o]); setv[o] = -1; } } int Upper_bound(int o, int L, int R, int v){ //找出第一个mexv比v大的下标 if(L == R) return L; pushdown(o, L, R); int M = (L + R) >> 1; int lc = o << 1, rc = lc | 1; return mexv[lc] > v ? Upper_bound(lc, L, M, v) : Upper_bound(rc, M+1, R, v); } void update(int o, int L, int R, int ql, int qr, int v){ //区间赋值:[ql, qr]赋为v if(ql <= L && R <= qr){ get(o, L, R, v); return; } pushdown(o, L, R); int M = (L + R) >> 1; int lc = o << 1, rc = lc | 1; if(ql <= M) update(lc, L, M, ql, qr, v); if(qr > M) update(rc, M+1, R, ql, qr, v); maintain(o, L, R); } void solve(){ //解决函数 long long ret = 0; for(int i = 1; i <= n; i++){ //枚举起点 ret += sumv[1]; //累加 if(a[i] < mexv[1]){ //这种情况下删了a[i]会使区间[Upper_bound, nxt[i]-1]的mex变成a[i] int ql = Upper_bound(1, 1, n, a[i]), qr = nxt[i] - 1; if(ql <= qr) update(1, 1, n, ql, qr, a[i]); //这个判断是必要的,若有数据:2 1 2 0 0,删第一个2时 } int ql = i, qr = i; update(1, 1, n, ql, qr, 0); //删除起点产生新起点 if(!mexv[1]) break; //剪枝 } printf("%I64d\n", ret); } int main() { while(scanf("%d", &n) == 1 && n){ read(); init(); getMex1(); getNxt(); build(1, 1, n); solve(); } return 0; }