对于线段树的讲解此篇不再赘述,下面列出线段树应用中最常用的几种操作的代码。(具体题目未贴出,仅供有一定基础者参考代码风格)
另外,注意多组输入要写scanf("%d%d",&n,&m)!=EOF,线段树的题肯定要用c语言的输入输出,要使用字符数组,不用字符串,输入字符的时候要加getchar()吞噬空行..
(1)单点增减,区间求和:
#include<iostream> #include<stdio.h> #include<string> #include<string.h> using namespace std; int a[50000*4];//开四倍即可 int sum[50000*4]; int T,n,b,c,ans; char s[20]; //单点增减,区间求和 int build(int l,int r,int o)//l左,r右,o结点 { if(l==r)//由上至下建树 return sum[ o]=a[l]; else { int mid=l+(r-l)/2; return sum[o]=build(l,mid,2*o)+build(mid+1,r,2*o+1); } } void query(int l,int r,int o)//b,c是要求和的区间 { if(b<=l&&c>=r)//注意是b,c包含l,r ans+=sum[o]; else { int mid=l+(r-l)/2; if(b<=mid) query(l,mid,o*2); if(c>mid) query(mid+1,r,2*o+1); } } void add(int l,int r,int o) { sum[o]+=c;//注意这个点要加的值 if(l==r) return; else { int mid=l+(r-l)/2; if(b<=mid) add(l,mid,o*2); else add(mid+1,r,o*2+1); } } int main() { int ca=1; scanf("%d",&T); while(T--) { scanf("%d",&n); for(int i=1;i<=n;i++) scanf("%d",&a[i]); printf("Case %d:\n",ca++); build(1,n,1); while(scanf("%s",s)) { if(s[0]=='E') break; if(s[0]=='A') { scanf("%d%d",&b,&c); add(1,n,1); } if(s[0]=='S') { scanf("%d%d",&b,&c); c=-c; add(1,n,1); } if(s[0]=='Q') { ans=0; scanf("%d%d",&b,&c); query(1,n,1); printf("%d\n",ans); } } } return 0;
int build(int l,int r,int o)//l左,r右,o结点 { if(l==r)//由上至下建树 return max1[o]=a[l];//最值建树 else { int mid=l+(r-l)/2; return max1[o]=max(build(l,mid,2*o),build(mid+1,r,2*o+1)); } } void update(int l,int r,int o)//单点更新 { // sum[o]+=c; max1[o]=max(max1[o],c);//从原来值和改变值之间取最大的 if(l==r) return; else { int mid=l+(r-l)/2; if(b<=mid) update(l,mid,o*2); else update(mid+1,r,o*2+1); } } void query(int l,int r,int o)//b,c是要求最值的区间 { if(b<=l&&c>=r)//注意是b,c包含l,r ans=max(ans,max1[o]);//取最大的,ans初值为0 //ans+=sum[o]; else { int mid=l+(r-l)/2; if(b<=mid) query(l,mid,o*2); if(c>mid) query(mid+1,r,2*o+1); } }
#include<iostream> #include<stdio.h> using namespace std; const int MAX_N=100100; int sum[MAX_N<<2],col[MAX_N<<2]; int n,q,x,y,c; void push_down(int l,int r,int o)//下放 { int m=(l+r)>>1; if(col[o]) { col[o<<1]=col[o<<1|1]=col[o]; sum[o<<1]=(m-l+1)*col[o]; sum[o<<1|1]=(r-m)*col[o]; col[o]=0; } } void build(int l,int r,int o) { col[o]=0; sum[o]=1; if(l==r) return ; int m=(l+r)>>1; build(l,m,o<<1); build(m+1,r,o<<1|1); sum[o]=sum[o<<1]+sum[o<<1|1]; } void update(int l,int r,int o) { int m=(l+r)>>1; if(x<=l&&r<=y) { col[o]=c; sum[o]=(r-l+1)*c; return ; } push_down(l,r,o); if(x<=m) update(l,m,o<<1); if(y>m) update(m+1,r,o<<1|1); sum[o]=sum[o<<1]+sum[o<<1|1]; } int main() { int T,cas=1; scanf("%d",&T); while(T--) { scanf("%d%d",&n,&q); build(1,n,1); while(q--) { scanf("%d%d%d", &x, &y, &c); update(1,n,1); } printf("Case %d: The total value of the hook is %d.\n", cas++, sum[1]); } return 0; }
#include <iostream> #include <cstdio> using namespace std; #define N 100005 __int64 sum[N << 2]; __int64 mark[N << 2]; __int64 basic_num[N<<2]; inline void pushUp (int o) { sum[o] = sum[o << 1] + sum[o << 1 | 1]; } void pushDown (int o, int len) { if (mark[o]) { //因为o的儿子节点可能被多次延迟标记,并且o的儿子节点的延迟标记没有向o的孙子节点移动,所以用“+=” mark[o << 1] += mark[o]; mark[o << 1 | 1] += mark[o]; /*此处用mark[o]乘以区间长度,不是mark[o<<1], 因为o的儿子节点如果被多次标记,之前被标记时, 就已经对sum[o<<1]更新过了。 */ sum[o << 1] += mark[o] * (len - (len >> 1)); sum[o << 1 | 1] += mark[o] * (len >> 1); mark[o] = 0; //将标记向儿子节点移动后,父节点的延迟标记去掉 } } void build (int l, int r, int o) { mark[o] = 0; if (l == r) { sum[o]=basic_num[l]; return; } int m = (l + r) >> 1; build (l, m, o << 1); build (m + 1, r, o << 1 | 1); pushUp (o);//sum值初始化 } void update (int L, int R, __int64 c, int l, int r, int o)//L,R是要更新的区间 { if (l >= L && r <= R) { mark[o] += c; sum[o] += c * (r - l + 1); return; } /*当要对被延迟标记过的这段区间的儿子节点进行更新时,先要将延迟标记向儿子节点移动 当然,如果一直没有对该段的儿子节点更新,延迟标记就不需要向儿子节点移动,这样就使 更新操作的时间复杂度仍为O(logn),也是使用延迟标记的原因。 */ pushDown(o, r - l + 1); //将已有的延迟下放 int m = (l + r) >> 1; if (m >= L) update (L, R, c, l, m, o << 1); if (m < R) update (L, R, c, m + 1, r, o << 1 | 1); pushUp (o); } __int64 query (int L, int R, int l, int r, int o) { if (l >= L && r <= R) return sum[o]; //要取o子节点的值时,也要先把o的延迟标记向下移动 pushDown(o, r - l + 1); int m = (l + r) >> 1; __int64 ret = 0; if (m >= L) ret += query (L, R, l, m, o << 1); if (m < R) ret += query (L, R, m + 1, r, o << 1 | 1); return ret; } int main() { int n, q, a, b; __int64 c; char op; scanf ("%d%d", &n,&q); for(int i=1;i<=n;i++) scanf("%I64d",&basic_num[i]); build (1, n, 1); while (q--) { getchar(); scanf ("%c %d %d", &op, &a, &b); if (op == 'Q') printf ("%I64d\n", query (a, b, 1, n, 1)); else { scanf ("%I64d", &c); update (a, b, c, 1, n, 1); } } return 0; }