(这次的代码写得非常的丑,因为敲代码时的环境非常的乱,大家可以不用看了。。。。。。)
题目大意:
给一个数字串。
然后定义一个函数Mex({A}) = A中没有的最小的非负整数。
即若A = {0、1、3},则Mex(A) = 2
然后要求数字串的所有连续子串的Mex值之和。
总结:
这一题我想了很久。这一题的解放真是太有意思了,处于我思维的一个盲点。
我一开始想过处理出Mex = 1、2、3、.......、n-1、n 的子串数,然后统计答案。
然后我从另外一个方向想:
依次去统计
[n, n]
[n-1, n-1] , [n-1, n]
[n-1, n-2] , [n-2, n-1] , [n-2, n]
......
[1 , 1] , [1 , 2] , [1 , 3] , [1 , 4] , [1 , 5] , ..... , [1, n-1] , [1 , n]
然后[i, x] 到[i-1, x] ,用什么数据结构去维护。
从这个方向去想也一直没有结果,后面看了网上的题解,才发现原来还能够这么做。
做法是先用O(N)的时间计算出[1, 1] , [1 , 2] , [1 , 3] , ...... , [1 , n-1] , [1 , n] 的值
然后在其基础上,用O(logn)的时间复杂度计算出
[2, 2] , [2, 3] , [2 , 4] , ....... , [2 , n-1] , [2 , n]的值。
这样转换n次,就能算出所有连续子串的合,并且时间复杂度是O(N*logN)
具体做法:
先用O(n)的时间计算出[1, 1] , [1, 2] , [1, 3] , [1, 4], ..... , [1, n-1] , [1, n]的Mex值。
具体代码:
value = 0;
for (i = 1; i <= n; i ++)
{
if (a[i] <= n)
flag[a[i]] = 1;
while (flag[value] == 1)
value ++;
h[i] = value;
}
h[i]表示的就是[1, i]的Mex值
由于Mex的性质,h[]的值是单调不下降序列。
然后考虑将h[i]从表示[1, i]的Mex值转换到[2, i]的Mex值。
考虑所有的连续子串减去a[1]以后,对h[]有什么影响。
若[1, i]去掉a[1]以后,[2, i]内仍然含有等于a[1]的值,则h[i]的值不发生改变。
若[1, i]去掉a[1]以后,[2, i]内不含有等于a[1]的值,则h[i]的值不能够大于a[i],即若此时h[i] > a[i] , 令h[i] = a[i]
这样就能将h[i]从表示[1, i]的Mex值转换到[2, i]的Mex值,此时的h[i]仍然是单调不下降的。
直接转换的时间复杂度是O(n)的,但由于h[]是一个单调不下降序列,所以可以使用线段树,使转换在O(logn)的时间内完成。
这个转换需要完成两个操作。
操作一,找出满足h[k] > a[i] 的最小的k
这个可以在线段树中加了一个max标记,表示这个区间内最大的h[]值,利用这个标记,可以在线段树上在O(logn)的时间内找出k值
int find(int x, int y, int value, int t) //在[x, y]中查询a[k] > value的k
{
int ans1, ans2;
if (x > lt[t].y || y < lt[t].x)
return n + 1;
if (lt[t].max <= value)
return n + 1;
if (lt[t].flag == 1 && lt[t].x != lt[t].y)
{
lt[t*2].get_property(lt[t]);
lt[t*2+1].get_property(lt[t]);
}
if (x <= lt[t].x && lt[t].y <= y)
{
if (lt[t].x == lt[t].y)
return lt[t].x;
else if (lt[t*2].max > value)
return find(x, y, value, t*2);
else return find(x, y, value, t*2+1);
}
else if (lt[t].x != lt[t].y)
{
ans1 = find(x, y, value, t*2);
ans2 = find(x, y, value, t*2+1);
return min(ans1, ans2);
}
}
操作二:在操作一的基础上,若k < a[i].next (a[i].next表示下一个等于a[i]的数字的下标),则将[k, a[i].next-1]上的h[i]值赋值为a[i]
这个因为每个的h[i]值只会下降不会上升,所以可以利用flag标记(表示整个区间是否等于同一个值)来进行线段树节点值传递。
这样就可以在O(logn)的时间内更新线段树的的max,sum,value,flag。(最大值、总和、当flag==1时,整个区间中的h[]值,整段区间相同标志)
我的更新操作代码:
void updata(int x, int y, int value, int t)
{
if (x > lt[t].y || y < lt[t].x)
return;
if (x <= lt[t].x && lt[t].y <= y)
{
lt[t].value = value;
lt[t].flag = 1;
lt[t].sum = 1ll * lt[t].value * (lt[t].y - lt[t].x + 1);
lt[t].max = value;
return;
}
if (lt[t].x != lt[t].y)
{
if (lt[t].flag == 1)
{
lt[t*2].get_property(lt[t]);
lt[t*2+1].get_property(lt[t]);
lt[t].flag = 0;
}
updata(x, y, value, t*2);
updata(x, y, value, t*2+1);
lt[t].max = max(lt[t*2].max, lt[t*2+1].max);
lt[t].sum = lt[t*2].sum + lt[t*2+1].sum;
}
}
最后贴上我全部的代码:
#include
#include
#include
using namespace std;
const int MAXN = 2e5 + 100;
class node
{
public:
int x, y, value, flag, max;
long long sum;
void get_property(node &b)
{
value = b.value;
flag = b.flag;
max = b.max;
sum = 1ll * (y - x + 1) * value;
}
};
node lt[MAXN*5];
int a[MAXN], flag[MAXN], h[MAXN];
int p_next[MAXN], head[MAXN];
int n;
void build(int x, int y, int t)
{
lt[t].x = x;
lt[t].y = y;
lt[t].value = 0;
lt[t].flag = 0;
lt[t].sum = 0;
lt[t].max = 0;
if (x != y)
{
int mid = (x + y) / 2;
build(x, mid, t*2);
build(mid+1, y, t*2+1);
}
}
void updata(int x, int y, int value, int t)
{
if (x > lt[t].y || y < lt[t].x)
return;
if (x <= lt[t].x && lt[t].y <= y)
{
lt[t].value = value;
lt[t].flag = 1;
lt[t].sum = 1ll * lt[t].value * (lt[t].y - lt[t].x + 1);
lt[t].max = value;
return;
}
if (lt[t].x != lt[t].y)
{
if (lt[t].flag == 1)
{
lt[t*2].get_property(lt[t]);
lt[t*2+1].get_property(lt[t]);
lt[t].flag = 0;
}
updata(x, y, value, t*2);
updata(x, y, value, t*2+1);
lt[t].max = max(lt[t*2].max, lt[t*2+1].max);
lt[t].sum = lt[t*2].sum + lt[t*2+1].sum;
}
}
int find(int x, int y, int value, int t) //在[x, y]中查询a[k] > value的k
{
int ans1, ans2;
if (x > lt[t].y || y < lt[t].x)
return n + 1;
if (lt[t].max <= value)
return n + 1;
if (lt[t].flag == 1 && lt[t].x != lt[t].y)
{
lt[t*2].get_property(lt[t]);
lt[t*2+1].get_property(lt[t]);
}
if (x <= lt[t].x && lt[t].y <= y)
{
if (lt[t].x == lt[t].y)
return lt[t].x;
else if (lt[t*2].max > value)
return find(x, y, value, t*2);
else return find(x, y, value, t*2+1);
}
else if (lt[t].x != lt[t].y)
{
ans1 = find(x, y, value, t*2);
ans2 = find(x, y, value, t*2+1);
return min(ans1, ans2);
}
}
void init()
{
int i;
for (i = 0; i <= n; i ++)
{
flag[i] = 0;
head[i] = n + 1;
}
}
int main()
{
int i, pos_x, pos_y, value;
long long ans;
while (scanf("%d",&n))
{
if (n == 0)
break;
init();
for (i = 1; i <= n; i ++)
{
scanf("%d", &a[i]);
if (a[i] > n)
a[i] = n + 10;
}
value = 0;
for (i = 1; i <= n; i ++)
{
if (a[i] <= n)
flag[a[i]] = 1;
while (flag[value] == 1)
value ++;
h[i] = value;
}
for (i = n; i >= 1; i --)
if (a[i] <= n)
{
p_next[i] = head[a[i]];
head[a[i]] = i;
}
build(1, n, 1);
for (i = 1; i <= n; i ++)
updata(i, i, h[i], 1);
ans = 0;
for (i = 1; i <= n; i ++)
{
ans += lt[1].sum;
if (a[i] <= n)
{
pos_y = p_next[i];
pos_x = find(i, pos_y, a[i], 1);
if (pos_x <= pos_y)
updata(pos_x, pos_y-1, a[i], 1);
}
}
printf("%lld\n", ans);
}
return 0;
}