hdu4747 Mex (线段树 好题)

(这次的代码写得非常的丑,因为敲代码时的环境非常的乱,大家可以不用看了。。。。。。)

 

题目大意:

给一个数字串。

然后定义一个函数Mex({A}) = A中没有的最小的非负整数。

即若A = {0、1、3},则Mex(A) = 2

然后要求数字串的所有连续子串的Mex值之和。

 

总结:

这一题我想了很久。这一题的解放真是太有意思了,处于我思维的一个盲点。

我一开始想过处理出Mex = 1、2、3、.......、n-1、n 的子串数,然后统计答案。

然后我从另外一个方向想:

依次去统计

[n, n]

[n-1, n-1] , [n-1, n]

[n-1, n-2] , [n-2, n-1] , [n-2, n]

......

[1 , 1] , [1 , 2] , [1 , 3] , [1 , 4] , [1 , 5] , ..... , [1, n-1] , [1 , n]

然后[i, x] 到[i-1, x] ,用什么数据结构去维护。

从这个方向去想也一直没有结果,后面看了网上的题解,才发现原来还能够这么做。

 

做法是先用O(N)的时间计算出[1, 1] , [1 , 2] , [1 , 3] , ...... , [1 , n-1] , [1 , n] 的值

然后在其基础上,用O(logn)的时间复杂度计算出

[2, 2] , [2, 3] , [2 , 4] , ....... , [2 , n-1] , [2 , n]的值。

这样转换n次,就能算出所有连续子串的合,并且时间复杂度是O(N*logN)

 

 

具体做法:

先用O(n)的时间计算出[1, 1] , [1, 2] , [1, 3] , [1, 4], ..... , [1, n-1] , [1, n]的Mex值。

具体代码:

		value = 0;
		for (i = 1; i <= n; i ++)
		{
			if (a[i] <= n) 
				flag[a[i]] = 1;
			while (flag[value] == 1)
				value ++;
			h[i] = value;
		}


 

h[i]表示的就是[1, i]的Mex值

由于Mex的性质,h[]的值是单调不下降序列。

然后考虑将h[i]从表示[1, i]的Mex值转换到[2, i]的Mex值。

考虑所有的连续子串减去a[1]以后,对h[]有什么影响。

若[1, i]去掉a[1]以后,[2, i]内仍然含有等于a[1]的值,则h[i]的值不发生改变。

若[1, i]去掉a[1]以后,[2, i]内不含有等于a[1]的值,则h[i]的值不能够大于a[i],即若此时h[i] > a[i] , 令h[i] = a[i]

这样就能将h[i]从表示[1, i]的Mex值转换到[2, i]的Mex值,此时的h[i]仍然是单调不下降的。

直接转换的时间复杂度是O(n)的,但由于h[]是一个单调不下降序列,所以可以使用线段树,使转换在O(logn)的时间内完成。

这个转换需要完成两个操作。

操作一,找出满足h[k] > a[i] 的最小的k

这个可以在线段树中加了一个max标记,表示这个区间内最大的h[]值,利用这个标记,可以在线段树上在O(logn)的时间内找出k值

int find(int x, int y, int value, int t)	//在[x, y]中查询a[k] > value的k 
{
	int ans1, ans2;
	if (x > lt[t].y || y < lt[t].x)
		return n + 1;
	if (lt[t].max <= value)
		return n + 1;
	if (lt[t].flag == 1 && lt[t].x != lt[t].y)
	{
		lt[t*2].get_property(lt[t]);
		lt[t*2+1].get_property(lt[t]);
	}
	if (x <= lt[t].x && lt[t].y <= y)
	{
		if (lt[t].x == lt[t].y)
			return lt[t].x;
		else if (lt[t*2].max > value)
			return find(x, y, value, t*2);
		else return find(x, y, value, t*2+1);
	}
	else if (lt[t].x != lt[t].y)
	{
		ans1 = find(x, y, value, t*2);
		ans2 = find(x, y, value, t*2+1);
		return min(ans1, ans2);
	}
}


 

操作二:在操作一的基础上,若k < a[i].next (a[i].next表示下一个等于a[i]的数字的下标),则将[k, a[i].next-1]上的h[i]值赋值为a[i]

这个因为每个的h[i]值只会下降不会上升,所以可以利用flag标记(表示整个区间是否等于同一个值)来进行线段树节点值传递。

这样就可以在O(logn)的时间内更新线段树的的max,sum,value,flag。(最大值、总和、当flag==1时,整个区间中的h[]值,整段区间相同标志)

我的更新操作代码:

void updata(int x, int y, int value, int t)
{
	if (x > lt[t].y || y < lt[t].x)
		return;
	if (x <= lt[t].x && lt[t].y <= y)
	{
		lt[t].value = value;
		lt[t].flag = 1;
		lt[t].sum = 1ll * lt[t].value * (lt[t].y - lt[t].x + 1);
		lt[t].max = value;
		return;
	}
	if (lt[t].x != lt[t].y)
	{
		if (lt[t].flag == 1)
		{
			lt[t*2].get_property(lt[t]);
			lt[t*2+1].get_property(lt[t]);
			lt[t].flag = 0;
		}
		updata(x, y, value, t*2);
		updata(x, y, value, t*2+1);
		lt[t].max = max(lt[t*2].max, lt[t*2+1].max);
		lt[t].sum = lt[t*2].sum + lt[t*2+1].sum;
	}
}


 

最后贴上我全部的代码:

#include 
#include 
#include 
using namespace std;

const int MAXN = 2e5 + 100;

class node
{
	public:
		int x, y, value, flag, max;
		long long sum;
		void get_property(node &b)
		{
			value = b.value;
			flag = b.flag;
			max = b.max;
			sum = 1ll * (y - x + 1) * value;
		}
};
node lt[MAXN*5];
int a[MAXN], flag[MAXN], h[MAXN];
int p_next[MAXN], head[MAXN];
int n;

void build(int x, int y, int t)
{
	lt[t].x = x;
	lt[t].y = y;
	lt[t].value = 0;
	lt[t].flag = 0;
	lt[t].sum = 0;
	lt[t].max = 0;
	if (x != y)
	{
		int mid = (x + y) / 2;
		build(x, mid, t*2);
		build(mid+1, y, t*2+1);
	}
}
void updata(int x, int y, int value, int t)
{
	if (x > lt[t].y || y < lt[t].x)
		return;
	if (x <= lt[t].x && lt[t].y <= y)
	{
		lt[t].value = value;
		lt[t].flag = 1;
		lt[t].sum = 1ll * lt[t].value * (lt[t].y - lt[t].x + 1);
		lt[t].max = value;
		return;
	}
	if (lt[t].x != lt[t].y)
	{
		if (lt[t].flag == 1)
		{
			lt[t*2].get_property(lt[t]);
			lt[t*2+1].get_property(lt[t]);
			lt[t].flag = 0;
		}
		updata(x, y, value, t*2);
		updata(x, y, value, t*2+1);
		lt[t].max = max(lt[t*2].max, lt[t*2+1].max);
		lt[t].sum = lt[t*2].sum + lt[t*2+1].sum;
	}
}
int find(int x, int y, int value, int t)	//在[x, y]中查询a[k] > value的k 
{
	int ans1, ans2;
	if (x > lt[t].y || y < lt[t].x)
		return n + 1;
	if (lt[t].max <= value)
		return n + 1;
	if (lt[t].flag == 1 && lt[t].x != lt[t].y)
	{
		lt[t*2].get_property(lt[t]);
		lt[t*2+1].get_property(lt[t]);
	}
	if (x <= lt[t].x && lt[t].y <= y)
	{
		if (lt[t].x == lt[t].y)
			return lt[t].x;
		else if (lt[t*2].max > value)
			return find(x, y, value, t*2);
		else return find(x, y, value, t*2+1);
	}
	else if (lt[t].x != lt[t].y)
	{
		ans1 = find(x, y, value, t*2);
		ans2 = find(x, y, value, t*2+1);
		return min(ans1, ans2);
	}
}
void init()
{
	int i;
	for (i = 0; i <= n; i ++)
	{
		flag[i] = 0;
		head[i] = n + 1;
	}
}
int main()
{
	int i, pos_x, pos_y, value;
	long long ans;
	while (scanf("%d",&n))
	{
		if (n == 0)
			break;
		init();
		for (i = 1; i <= n; i ++)
		{
			scanf("%d", &a[i]);
			if (a[i] > n)
				a[i] = n + 10;
		}
		value = 0;
		for (i = 1; i <= n; i ++)
		{
			if (a[i] <= n) 
				flag[a[i]] = 1;
			while (flag[value] == 1)
				value ++;
			h[i] = value;
		}
		for (i = n; i >= 1; i --)
			if (a[i] <= n)
			{
				p_next[i] = head[a[i]];
				head[a[i]] = i;
			}
		build(1, n, 1);
		for (i = 1; i <= n; i ++)
			updata(i, i, h[i], 1);
		ans = 0;
		for (i = 1; i <= n; i ++)
		{
			ans += lt[1].sum;
			if (a[i] <= n)
			{
				pos_y = p_next[i];
				pos_x = find(i, pos_y, a[i], 1);
				if (pos_x <= pos_y)
					updata(pos_x, pos_y-1, a[i], 1);
			}
		}
		printf("%lld\n", ans);
	}
	return 0;
}


 

 

 

 

你可能感兴趣的:(acm训练,线段树)