WOJ 1618 - Magic Array (线段树+单调栈)

题意:给定n(n<=500000)个数,A[1],A[2],...,A[n]。求所有子区间的 (最大值*最小值*长度)之和,对 10^9 取余数。


思路一:暴力     时间复杂度O(n^3)    超时

枚举所有子区间,然后遍历一遍来求最大值和最小值,然后累加答案。


思路二:稍加优化    时间复杂度O(n^2)   超时

枚举R,然后对于每个L<=R,用Max[L]表示区间[L,R]的最大值,Min[L]表示区间[L..R]的最小值。

在R增加1之后,每个L<=旧R的Max[L]和Min[L]只需要用A[R]去更新就行了。

代码:

//Simple solve
int Min[maxn],Max[maxn];
int SimpleSolve(){
	LL ANS=0;
	for(int R=1;R <=n;++R){
		Min[R]=Max[R]=A[R];
		for(int L=1;L <= R;++L){
			//更新Min[L]和Max[L] 
			Min[L]=min(Min[L],A[R]);
			Max[L]=max(Max[L],A[R]);
			//累加区间[L,R]的答案 
			ANS=(ANS+(LL)Min[L]*Max[L]%MOD*(R-L+1))%MOD;
		}
	}
	return ANS;
}


思路三: 线段树优化   时间复杂度:O(n*log(n))


思路三是思路二的线段树优化,所以在看思路三之前,请确保看懂了思路二。

思路二中,对于每个R,用A[R]更新了[1..R]的最大值和最小值,然后累加了[1..R]到R的答案。

如果更新和求和都使用线段树的话,时间复杂度就变成了O(n*log(n))。


首先,如何用线段树维护最大值。

对于每个R,要将Max[1..R]的数组中,所有小于A[R]的都更新成A[R]。

注意到Max[1],Max[2],...,Max[R]是单调非增的数列。

(Max[1]代表区间[1..R]的最大值,Max[2]代表区间[2..R]的最大值,明显Max[1]>=Max[2].)

所以,实际上,从某下标L开始,Max[L..R]都小于A[R],于是变成了线段树的区间修改:将[L..R]的数变成A[R]。

那么,怎么找到L值呢?在线段树的节点上多维护一个最大值的最大值,然后就可以判断了,具体看代码。


最小值同理。


然后来谈一谈线段树每个节点需要的变量,以下用m表示最小值,用M表示最大值,用L表示长度。

变量分为:标记量和统计量。


先来谈标记量,需要最小值标记,最大值标记,和长度标记,记做m,M,L。

毕竟是区间修改,所以需要三种标记。


统计量

然后明显要有(最大值*最小值*长度)的和,记做 smML。


在修改了最大值或最小值或长度时,要能够直接更新smML这个变量,

所以需要记录(最大值*最小值)的和,(最大值*长度)的和,(最小值*长度)的和,分别记做smM,sML,smL。


然后,在修改了最大值或最小值或长度时,要能够直接更新smM,sML,smL这三个变量,

需要记录最小值之和,最大值之和,长度之和,分别记做sm,sM,sL。


于是,s开头的求和的统计量需要7个。


在增加一个线段树区间的长度时,要更新sL,需要sL加上该区间的长度,所以用变量n表示该线段树区间的长度。


为了在更新最大值的时候可以找到左边界,需要记录最大值的最大值,记为MM。

同理,记录最小值的最小值,记为mm。


于是,线段树的一个节点,需要3个标记量 和 10个统计量。


节点的定义见下面代码:

//Segment Tree Node
struct Node{
	int m,M,L;//min max Len  3个标记量
	int mm,MM,n;//min of min,max of max , number of intervals
	int sm,sM,sL,smM,smL,sML,smML;//sum of products
	Node(){m=M=L=0;}
	Node operator+(const Node &B){//节点的统计量的更新
		Node &A = *this,C;
		C.mm = min(A.mm,B.mm);
		C.MM = max(A.MM,B.MM);
		C.n = A.n + B.n;
		C.sm   = (A.sm   + B.sm  ) % MOD;
		C.sM   = (A.sM   + B.sM  ) % MOD;
		C.sL   = (A.sL   + B.sL  ) % MOD;
		C.smM  = (A.smM  + B.smM ) % MOD;
		C.smL  = (A.smL  + B.smL ) % MOD;
		C.sML  = (A.sML  + B.sML ) % MOD;
		C.smML = (A.smML + B.smML) % MOD;
		return C;
	}
	void SetMax(int Max){//将该区间的最大值改为Max,修改最大值标记,以及对应的统计量
		M = MM = Max;
		sM   = (LL)n   * M % MOD;//  最大值的和               = 数量                * 最大值
		smM  = (LL)sm  * M % MOD;//(最小值*最大值)的和      = 最小值的和          * 最大值
		sML  = (LL)sL  * M % MOD;//(最大值*长度)的和        = 长度的和            * 最大值
		smML = (LL)smL * M % MOD;//(最大值*最小值*长度)的和 = (最小值*长度)的和 * 最大值
	}
	void SetMin(int Min){//将该区间的最小值改为Min,修改最小值标记,以及对应的统计量
		m = mm = Min;
		sm   = (LL)n   * m % MOD;
		smM  = (LL)sM  * m % MOD;
		smL  = (LL)sL  * m % MOD;
		smML = (LL)sML * m % MOD;
	}
	void AddLen(LL k){//该区间的长度增加k
		L += k;
		sL   = (sL   + k*n  )%MOD  ;
		smL  = (smL  + k*sm )%MOD ;
		sML  = (sML  + k*sM )%MOD ;
		smML = (smML + k*smM)%MOD;
	}
	void SetValue(LL V){//设置叶节点的值
		sL=n=1;
		smL=sML=sm=sM=mm=MM=V;
		smML=smM=V*V%MOD;
	}
};

还剩下最大值的更新的左端点怎么找的问题:

首先,要在[1..R]这个区间中,将所有的比A[R]小的值变成A[R]。

第一部分:常规的线段树区间判断,可以得到所有在[1..R]之内的区间。

第二部分:如果本区间的最大值小于等于V,直接将整个区间的最大值设置为V即可。

如果不是叶节点,那么进行递归调用,右区间一定要递归调用,左区间根据条件。

如果右区间的最大值小于等于V,那么左侧可能也有要更新的区间。所以要递归调用左侧。

具体见UpdateMax函数:

void UpdateMax(int X,int V,int l,int r,int rt){//[1,X]
	int L = rt << 1 , R = rt << 1 | 1 , m = (l + r) >> 1;
	if(r <= X){//第二部分,得到了[1..X]的区间之后 
		if(D[rt].MM <= V){//如果本区间的最大值小于等于V,直接将本区间的最大值设置为V 
			D[rt].SetMax(V);
			return;
		}
		if(l==r) return;//如果是叶节点,直接返回 
		PushDown(rt);
		UpdateMax(X,V,rs);//更新右侧 
		//如果右侧的最大值大于V,那么左侧不可能有需要更新的值,所以不需要递归左侧
		//否则,需要递归左侧 
		if(D[R].MM <= V) UpdateMax(X,V,ls);  
		PushUp(rt);
		return;
	}
	//第一部分:常规线段树区间判断,可以得到所有[1..X]之内的区间 
	PushDown(rt);
	UpdateMax(X,V,ls);
	if(X > m) UpdateMax(X,V,rs);
	PushUp(rt);
}

最后总结一下:

空间的问题:50万数据量,需要的线段树元素个数是1048576个,要是按一般的做法,直接四倍的话,就超出空间范围了。

时间的问题:开始写的是,先用线段树搜索出更新最大值的左侧下标L,再区间修改[L,R],然后超时了。

后来改成了在修改的时候顺便寻找修改边界,就快了许多。


-----------------------------------------------------------------------------  分割线  ------------------------------------------------------------------

上面说的方法,用时9.4秒。经过三个优化之后,可以达到1.5秒。

优化一:可以发现,对于每个R,如果A[R]>A[R-1]那么只需要更新最大值数组,最小值都不需要更新。节省了一半的线段树操作。

优化二:对于L值,前面的做法是把它当做数据来维护,其实不需要。可以直接在Query函数中计算,省掉了更新长度的操作,以及长度的懒惰标记。

优化三:前面的做法是在更新最大值的时候,利用统计量MM来找到左边界。其实可以用单调栈直接维护左边界。省去了左边界判断时间,以及mm,MM两个变量。


第一份代码如下:

/*
	Problem 1618 - Magic Array 
	56520KB   9420ms 
*/ 
#include 
#include 
#include 
#include 
#include 
#define LL long long
#define MOD 1000000000
#define maxn 500007
#define ls l,m,rt<<1
#define rs m+1,r,rt<<1|1
using namespace std;
//Input
int n,A[maxn];
//Segment Tree Node
struct Node{
	int m,M,L;//min max Len
	int mm,MM,n;//min of min,max of max , number of intervals
	int sm,sM,sL,smM,smL,sML,smML;//sum of products
	Node(){m=M=L=0;}
	Node operator+(const Node &B){
		Node &A = *this,C;
		C.mm = min(A.mm,B.mm);
		C.MM = max(A.MM,B.MM);
		C.n = A.n + B.n;
		C.sm   = (A.sm   + B.sm  ) % MOD;
		C.sM   = (A.sM   + B.sM  ) % MOD;
		C.sL   = (A.sL   + B.sL  ) % MOD;
		C.smM  = (A.smM  + B.smM ) % MOD;
		C.smL  = (A.smL  + B.smL ) % MOD;
		C.sML  = (A.sML  + B.sML ) % MOD;
		C.smML = (A.smML + B.smML) % MOD;
		return C;
	}
	void SetMax(int Max){
		M = MM = Max;
		sM   = (LL)n   * M % MOD;
		smM  = (LL)sm  * M % MOD;
		sML  = (LL)sL  * M % MOD;
		smML = (LL)smL * M % MOD;
	}
	void SetMin(int Min){
		m = mm = Min;
		sm   = (LL)n   * m % MOD;
		smM  = (LL)sM  * m % MOD;
		smL  = (LL)sL  * m % MOD;
		smML = (LL)sML * m % MOD;
	}
	void AddLen(LL k){
		L += k;
		sL   = (sL   + k*n  )%MOD  ;
		smL  = (smL  + k*sm )%MOD ;
		sML  = (sML  + k*sM )%MOD ;
		smML = (smML + k*smM)%MOD;
	}
	void SetValue(LL V){
		sL=n=1;
		smL=sML=sm=sM=mm=MM=V;
		smML=smM=V*V%MOD;
	}
}D[1048576];
void PushUp(int rt){D[rt] = D[rt<<1] + D[rt<<1|1];}
void PushDown(int rt){//Push down three marks
	int L = rt << 1 , R = rt << 1 | 1;
	if(D[rt].M){
		D[L].SetMax(D[rt].M);
		D[R].SetMax(D[rt].M);
		D[rt].M=0;
	}
	if(D[rt].m){
		D[L].SetMin(D[rt].m);
		D[R].SetMin(D[rt].m);
		D[rt].m=0;
	}
	if(D[rt].L){
		D[L].AddLen(D[rt].L);
		D[R].AddLen(D[rt].L);
		D[rt].L=0;
	}
}
void Build(int l,int r,int rt){
	if(l==r){
		D[rt].SetValue(A[l]);
		return;
	}
	int m=(l+r)>>1;
	Build(ls);
	Build(rs);
	PushUp(rt);
}
void UpdateMax(int X,int V,int l,int r,int rt){//[1,X]
	int L = rt << 1 , R = rt << 1 | 1 , m = (l + r) >> 1;
	if(r <= X){
		if(D[rt].MM <= V){
			D[rt].SetMax(V);
			return;
		}
		if(l==r) return;
		PushDown(rt);
		UpdateMax(X,V,rs);
		if(D[R].MM <= V) UpdateMax(X,V,ls);  
		PushUp(rt);
		return;
	}
	PushDown(rt);
	UpdateMax(X,V,ls);
	if(X > m) UpdateMax(X,V,rs);
	PushUp(rt);
}
void UpdateMin(int X,int V,int l,int r,int rt){//[1,X]
	int L = rt << 1 , R = rt << 1 | 1 , m = (l + r) >> 1;
	if(r <= X){
		if(D[rt].mm >= V){
			D[rt].SetMin(V);
			return;
		}
		if(l==r) return;
		PushDown(rt);
		UpdateMin(X,V,rs);
		if(D[R].mm >= V) UpdateMin(X,V,ls);  
		PushUp(rt);
		return;
	}
	PushDown(rt);
	UpdateMin(X,V,ls);
	if(X > m) UpdateMin(X,V,rs);
	PushUp(rt);
}
void UpdateLen(int X,int l,int r,int rt){//[1,X] 
	if(r <= X){
		D[rt].AddLen(1);
		return;
	}
	PushDown(rt);
	int m = (l + r) >> 1;
	UpdateLen(X,ls);
	if(X > m) UpdateLen(X,rs);
	PushUp(rt);
}
LL Query(int X,int l,int r,int rt){//求和
	if(r <= X){
		return D[rt].smML;
	}
	PushDown(rt);
	int m = (l + r) >> 1;
	LL ANS = Query(X,ls);
	if(X > m) ANS = (ANS + Query(X,rs)) % MOD;
	return ANS;
}
int main(void)
{
	while(~scanf("%d",&n)){
		for(int i=1;i<=n;++i) scanf("%d",&A[i]);
		Build(1,n,1);
		LL ANS = Query(1,1,n,1);
		for(int R=2;R <= n;++R){
			UpdateMax(R,A[R],1,n,1);//更新最大值
			UpdateMin(R,A[R],1,n,1);//更新最小值
			UpdateLen(R-1,1,n,1);//更新长度
			ANS = (ANS + Query(R,1,n,1)) % MOD;//累加答案
		}
		printf("%d\n",(int)ANS);
	}
	return 0;
}



优化后代码如下:

/*
	Problem 1618 - Magic Array 
	Memory: 44260KB  Time: 1500ms
*/ 
#include 
 #include 
 #include 
 #include 
 #include 
 #define LL long long
 #define MOD 1000000000
 #define maxn 500007
 #define ls l,m,rt<<1
 #define rs m+1,r,rt<<1|1
 using namespace std;
//Input
 int n,A[maxn];
 int Min[maxn],IMin;
int Max[maxn],IMax;
//Segment Tree Node
 struct Node{
     int m,M;//min max
     int n;//number of intervals
     int sm,sM,sL,smM,smL,sML,smML;//sum of products
     Node(){m=M=0;}
     Node operator+(const Node &B)const{
         const Node &A = *this;
         Node C; 
         C.n = A.n + B.n;
         C.sm  = (A.sm  + B.sm  ) % MOD;
         C.sM  = (A.sM  + B.sM  ) % MOD;
         C.smM  = (A.smM  + B.smM ) % MOD;
         C.sL  = (A.sL  + (LL)A.n  * B.n + B.sL  ) % MOD;
         C.smL  = (A.smL  + (LL)A.sm  * B.n + B.smL ) % MOD;
         C.sML  = (A.sML  + (LL)A.sM  * B.n + B.sML ) % MOD;
         C.smML = (A.smML + (LL)A.smM * B.n + B.smML) % MOD;
         return C;
     }
     void SetMax(int Max){
         M = Max;
         sM  = (LL)n  * M % MOD;
         smM  = (LL)sm  * M % MOD;
         sML  = (LL)sL  * M % MOD;
         smML = (LL)smL * M % MOD;
     }
     void SetMin(int Min){
         m = Min;
         sm  = (LL)n  * m % MOD;
         smM  = (LL)sM  * m % MOD;
         smL  = (LL)sL  * m % MOD;
         smML = (LL)sML * m % MOD;
     }
     void SetValue(LL V){
         sL=n=1;
         smL=sML=sm=sM=V;
         smML=smM=V*V%MOD;
     }
 }D[1048576];
void PushUp(int rt){D[rt] = D[rt<<1] + D[rt<<1|1];}
void PushDown(int rt){//Push down three marks
     int L = rt << 1 , R = rt << 1 | 1;
     if(D[rt].M){
         D[L].SetMax(D[rt].M);
         D[R].SetMax(D[rt].M);
         D[rt].M=0;
     }
     if(D[rt].m){
         D[L].SetMin(D[rt].m);
         D[R].SetMin(D[rt].m);
         D[rt].m=0;
     }
 }
void Build(int l,int r,int rt){
     if(l==r){
         D[rt].SetValue(A[l]);
         return;
     }
     int m=(l+r)>>1;
     Build(ls);
     Build(rs);
     PushUp(rt);
 }
 void UpdateMax(int L,int R,int V,int l,int r,int rt){
	if(L <= l && r <= R){
		D[rt].SetMax(V);
		return;
	}
	PushDown(rt);
	int m = (l + r) >> 1;
	if(L <= m) UpdateMax(L,R,V,ls);
	if(R >  m) UpdateMax(L,R,V,rs);
	PushUp(rt);
}
void UpdateMin(int L,int R,int V,int l,int r,int rt){
	if(L <= l && r <= R){
		D[rt].SetMin(V);
		return;
	}
	PushDown(rt);
	int m = (l + r) >> 1;
	if(L <= m) UpdateMin(L,R,V,ls);
	if(R >  m) UpdateMin(L,R,V,rs);
	PushUp(rt);
}
 Node Query(int X,int l,int r,int rt){
     if(r <= X){
         return D[rt];
     }
     PushDown(rt);
     int m = (l + r) >> 1;
     Node ANS = Query(X,ls);
     if(X > m) ANS = ANS + Query(X,rs);
     return ANS;
 }

int main(void)
{
    while(~scanf("%d",&n)){
        for(int i=1;i<=n;++i) scanf("%d",&A[i]);
        Build(1,n,1);
        Min[0]=Max[0]=IMin=IMax=0;
        LL ANS = 0;
        for(int R=1;R <= n;++R){
        	while(IMin && A[R]<=A[Min[IMin]]) --IMin;
			Min[++IMin]=R;
			while(IMax && A[R]>=A[Max[IMax]]) --IMax;
			Max[++IMax]=R;
            if(A[R]>A[R-1]) UpdateMax(Max[IMax-1]+1,R,A[R],1,n,1);
            else UpdateMin(Min[IMin-1]+1,R,A[R],1,n,1);
            ANS = (ANS + Query(R,1,n,1).smML) % MOD;
        }
        printf("%d\n",(int)ANS);
    }
    return 0;
}





你可能感兴趣的:(线段树/平衡树)