首先,要先搞清楚线段树需要维护的几个点:
左右端点:l,r;
区间和:sum;
紧靠左端点的最大子段和:lm;
紧靠右端点的最大子段和:rm;
最大连续子段和:mx;
。。。。。
接下来就是我们该如何去维护这些值:
sum:很明显就是左右子节点之和;
lm:要么最大值只在左子节点lm,要么左子节点sum加上右子节点的lm;(要保证连续)
同理rm:右子节点rm或者右子节点sum加左子节点rm;
mx:最大连续子段和也是要么单独存在于左或右子节点中,要么跨过这两个段:左子节点rm+右子节点lm;
void update(int k)
{
tr[k].sum=tr[k<<1].sum+tr[k<<1|1].sum;
tr[k].mx=max(tr[k<<1].mx,tr[k<<1|1].mx);
tr[k].mx=max(tr[k].mx,tr[k<<1].rm+tr[k<<1|1].lm);
tr[k].lm=max(tr[k<<1].lm,tr[k<<1].sum+tr[k<<1|1].lm);
tr[k].rm=max(tr[k<<1|1].rm,tr[k<<1|1].sum+tr[k<<1].rm);
}
最后如何查询给定一个区间的最大值:
其实和上面的思想差不多,最大连续子段和要么单独存在左或右子节点中,要么跨段存在,如果当前节点在所查范围内,return 此节点的所有值就可以了。
node ask(int k,int l,int r)
{
if(l<=tr[k].l&&r>=tr[k].r){
return tr[k];
}
int mid=(tr[k].l+tr[k].r)>>1;
if(r<=mid)return ask(k<<1,l,r);
if(l>mid)return ask(k<<1|1,l,r);
else { //如果不单独存在左右子节点,就需要单独查询,然后就和上面update一样,合并一下就ok了
node a,b,c; //需要三个值还是查询一样的思想
a=ask(k<<1,l,mid);
b=ask(k<<1|1,mid+1,r);
c.sum=a.sum+b.sum;
c.lm=max(a.lm,a.sum+b.lm);
c.rm=max(b.rm,b.sum+a.rm);
c.mx=max(a.rm+b.lm,max(a.mx,b.mx));
return c;
}
}
下面完整代码
#include
#include
#include
#include
#define maxn 50010
typedef long long ll;
using namespace std;
struct node{
int l,r;
ll lm,rm,sum,mx;
}tr[maxn<<2];
ll a[maxn];
void update(int k)
{
tr[k].sum=tr[k<<1].sum+tr[k<<1|1].sum;
tr[k].mx=max(tr[k<<1].mx,tr[k<<1|1].mx);
tr[k].mx=max(tr[k].mx,tr[k<<1].rm+tr[k<<1|1].lm);
tr[k].lm=max(tr[k<<1].lm,tr[k<<1].sum+tr[k<<1|1].lm);
tr[k].rm=max(tr[k<<1|1].rm,tr[k<<1|1].sum+tr[k<<1].rm);
}
void build(int k,int l,int r)
{
tr[k].l=l,tr[k].r=r;
if(l==r){
ll v;
scanf("%lld",&v);
tr[k].lm=tr[k].rm=tr[k].sum=tr[k].mx=v;
return ;
}
int mid=(l+r)>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
update(k);
}
node ask(int k,int l,int r)
{
if(l<=tr[k].l&&r>=tr[k].r){
return tr[k]; //将当前节点所有值返回,以便后续合并,最后就是要查询的结果
}
int mid=(tr[k].l+tr[k].r)>>1;
if(r<=mid)return ask(k<<1,l,r);
if(l>mid)return ask(k<<1|1,l,r);
else {
node a,b,c;
a=ask(k<<1,l,mid);
b=ask(k<<1|1,mid+1,r);
c.sum=a.sum+b.sum;
c.lm=max(a.lm,a.sum+b.lm);
c.rm=max(b.rm,b.sum+a.rm);
c.mx=max(a.rm+b.lm,max(a.mx,b.mx));
return c;
}
}
void change(int k,int x,ll v)
{
if(tr[k].l==tr[k].r){
tr[k].lm=tr[k].rm=tr[k].sum=tr[k].mx=v;
return ;
}
int mid=(tr[k].l+tr[k].r)>>1;
if(x<=mid)change(k<<1,x,v);
else change(k<<1|1,x,v);
update(k);
}
int main()
{
int n,q;
while(~scanf("%d",&n))
{
build(1,1,n);
scanf("%d",&q);
while(q--){
int a,b,c;
scanf("%d %d %d",&a,&b,&c);
if(a){
if(b>c)swap(b,c);
node tem=ask(1,b,c);
printf("%lld\n",tem.mx);
}
else {
change(1,b,c);
}
}
}
return 0;
}