学习线段树——一定要好好理解。
此题是单点更新的线段树,只需考虑两种基本操作:区间求和询问与单点更新。
// Segment tree with point updates and range-sum queries (point-update version).
//
// Input, as read by main(): a test-case count, then for each case an integer
// n, n initial values, and commands identified by their first letter:
//   'A' i j  add j at position i
//   'S' i j  subtract j at position i
//   'Q' i j  print the sum of positions [i, j]
//   'E'      end of this test case
#include <iostream>
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>
#include <set>
#include <queue>
#include <stack>
#include <climits> // constants such as INT_MAX
#define MAX 50005
#define INF 0x7FFFFFFF
#define REP(i,s,t) for(int i=(s);i<=(t);++i)
#define ll long long
#define mem(a,b) memset(a,b,sizeof(a))
#define mp(a,b) make_pair(a,b)
#define eps 1e-5
//#pragma comment(linker, "/STACK:36777216") // stack-size hack (disabled)
using namespace std;

// One tree node covering positions [left, right]. Node k's children are
// 2k and 2k+1; the root is node 1. 4*MAX nodes suffice for n <= MAX.
struct node
{
    int left, right; // interval covered by this node
    int mid;         // split point: left child [left, mid], right child [mid+1, right]
    int value;       // sum of the values in [left, right]
    int cover;       // only used by the commented-out insert() below
} edge[4 * MAX];

// Builds the subtree rooted at index `num` over [l, r], with every sum 0.
// An empty range (l > r, i.e. n == 0) becomes a single childless node
// instead of recursing forever.
void build(int l, int r, int num)
{
    edge[num].left = l;
    edge[num].right = r;
    edge[num].value = 0;
    edge[num].mid = (l + r) >> 1;
    if (l < r) // internal node: build both halves
    {
        build(l, edge[num].mid, num * 2);
        build(edge[num].mid + 1, r, num * 2 + 1);
    }
}

// Point update: adds v to position `mid` (the position to update, not a node
// split point) inside the subtree rooted at `num`.
// Positions outside the node's interval are ignored.
void update(int num, int mid, int v)
{
    if (mid < edge[num].left || mid > edge[num].right)
        return; // position not in this subtree (guards against bad input)
    if (edge[num].left == edge[num].right) // the leaf for this position
    {
        edge[num].value += v;
        return;
    }
    if (mid <= edge[num].mid)
        update(2 * num, mid, v);
    else
        update(2 * num + 1, mid, v);
    // Recompute from the children. Use '=' (not '+='): the children already include v.
    edge[num].value = edge[num * 2].value + edge[num * 2 + 1].value;
}

// Returns the sum of positions [l, r] that lie inside the subtree rooted at
// `num`. Returns 0 for an empty (l > r) or disjoint interval.
int query(int l, int r, int num)
{
    if (l > r || r < edge[num].left || l > edge[num].right)
        return 0; // nothing of [l, r] lies in this node
    if (l <= edge[num].left && edge[num].right <= r)
        return edge[num].value; // node fully inside [l, r]
    // Partial overlap: only internal nodes can partially overlap, so both children exist.
    return query(l, r, num * 2) + query(l, r, num * 2 + 1);
}

/*
// Interval-cover insertion (not used by this program): marks every node whose
// interval lies completely inside [l, r] as covered.
void insert(int l, int r, int num) // l, r: inserted interval; num: current node index
{
    if (l <= edge[num].left && r >= edge[num].right) // node fully inside [l, r]
    {
        edge[num].cover = 1;
        return;
    }
    if (r <= edge[num].mid) // [l, r] lies entirely in the left child
        insert(l, r, 2 * num);
    else if (l >= edge[num].mid + 1) // [l, r] lies entirely in the right child
        insert(l, r, 2 * num + 1);
    else
    {
        insert(l, edge[num].mid, 2 * num);
        insert(edge[num].mid + 1, r, 2 * num + 1);
    }
}
*/

int main()
{
    int t = 0, i = 0, j = 0;
    int casee = 1;
    char str[10];
    cin >> t;
    while (t--)
    {
        printf("Case %d:\n", casee++);
        int n = 0, a = 0;
        if (!(cin >> n) || n > MAX)
            break; // missing input, or more positions than the tree can hold
        build(1, n, 1);
        for (i = 1; i <= n; i++)
        {
            if (scanf("%d", &a) != 1)
                a = 0; // truncated input: treat a missing value as 0
            update(1, i, a);
        }
        // scanf returns EOF (non-zero) at end of input, so compare with 1;
        // otherwise a missing "End" line would loop forever. %9s bounds str.
        while (scanf("%9s", str) == 1)
        {
            if (str[0] == 'E')
                break;
            if (str[0] != 'Q' && str[0] != 'A' && str[0] != 'S')
                continue; // unknown command: ignore it
            if (scanf("%d%d", &i, &j) != 2)
                break; // truncated command arguments
            if (str[0] == 'Q')
                printf("%d\n", query(i, j, 1));
            else if (str[0] == 'A')
                update(1, i, j);
            else
                update(1, i, -j);
        }
    }
    return 0;
}
成段更新版(区间更新,使用延迟标记 add 实现):
// Segment tree with lazy propagation ("range-update version").
// It supports adding k to every position of [l, r] and range-sum queries.
// main() only issues single-position updates (update(a, a, ...)), but
// update() accepts any interval.
//
// Input, as read by main(): a test-case count, then for each case an integer
// n, n initial values, and commands identified by their first letter:
//   'A' a b  add b at position a
//   'S' a b  subtract b at position a
//   'Q' a b  print the sum of positions [a, b]
//   'E'      end of this test case
#include <iostream>
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>
#include <set>
#include <queue>
#include <stack>
#include <climits> // constants such as INT_MAX
#define MAX 50005
#define INF 0x7FFFFFFF
#define REP(i,s,t) for(int i=(s);i<=(t);++i)
#define ll long long
#define mem(a,b) memset(a,b,sizeof(a))
#define mp(a,b) make_pair(a,b)
#define L(x) ((x) << 1)       // left child index (fully parenthesised)
#define R(x) (((x) << 1) | 1) // right child index
#define eps 1e-5
//#pragma comment(linker, "/STACK:36777216") // stack-size hack (disabled)
using namespace std;

// One tree node covering positions [left, right]; root is node 1.
struct node
{
    int left, right, mid; // covered interval and its split point
    int value;            // sum over [left, right], already including this node's pending add
    int add;              // amount still to be pushed down to both children
} edge[4 * MAX];

int aa[MAX]; // initial values, 1-based; requires n < MAX

// Recomputes node x's sum from its two children.
void push_up(int x)
{
    edge[x].value = edge[L(x)].value + edge[R(x)].value;
}

// Adds k to every position covered by node x. The node's sum is updated now,
// and k is recorded as pending for its children.
static void apply_add(int x, int k)
{
    edge[x].value += (edge[x].right - edge[x].left + 1) * k;
    edge[x].add += k;
}

// Moves node x's pending add into its two children.
void push_down(int x)
{
    if (edge[x].add)
    {
        apply_add(L(x), edge[x].add);
        apply_add(R(x), edge[x].add);
        edge[x].add = 0;
    }
}

// Builds the subtree rooted at `num` over [l, r] from aa[]. An empty range
// (l > r, i.e. n == 0) becomes a childless node with sum 0 instead of
// recursing forever.
void build(int l, int r, int num)
{
    edge[num].left = l;
    edge[num].right = r;
    edge[num].mid = (l + r) >> 1;
    edge[num].add = 0;
    if (l >= r) // leaf (or empty range)
    {
        edge[num].value = (l == r) ? aa[l] : 0;
        return;
    }
    build(l, edge[num].mid, L(num));
    build(edge[num].mid + 1, r, R(num));
    push_up(num);
}

// Adds k to every position of [l, r] inside the subtree rooted at `num`.
// Empty or disjoint intervals are ignored.
void update(int l, int r, int k, int num)
{
    if (l > r || r < edge[num].left || l > edge[num].right)
        return; // nothing of [l, r] lies in this node
    if (l <= edge[num].left && edge[num].right <= r)
    {
        apply_add(num, k); // fully covered: stop here, defer to children lazily
        return;
    }
    // Partial overlap (internal node only): bring children up to date first.
    push_down(num);
    update(l, r, k, L(num));
    update(l, r, k, R(num));
    push_up(num);
}

// Returns the sum of positions [l, r] inside the subtree rooted at `num`.
// Returns 0 for an empty or disjoint interval.
int query(int l, int r, int num)
{
    if (l > r || r < edge[num].left || l > edge[num].right)
        return 0;
    if (l <= edge[num].left && edge[num].right <= r)
        return edge[num].value;
    // Required: this was commented out before. Without it, children below a
    // node with a pending add report stale sums after a true range update.
    push_down(num);
    return query(l, r, L(num)) + query(l, r, R(num));
}

int main()
{
    int t = 0, a = 0, b = 0;
    char str[10];
    cin >> t;
    int casee = 1;
    while (t--)
    {
        printf("Case %d:\n", casee++);
        int n = 0, i;
        if (!(cin >> n) || n >= MAX)
            break; // missing input, or more values than aa[] can hold
        for (i = 1; i <= n; i++)
            if (scanf("%d", &aa[i]) != 1)
                aa[i] = 0; // truncated input: treat a missing value as 0
        build(1, n, 1);
        // scanf returns EOF (non-zero) at end of input, so compare with 1;
        // otherwise a missing "End" line would loop forever. %9s bounds str.
        while (scanf("%9s", str) == 1)
        {
            if (str[0] == 'E')
                break;
            if (str[0] != 'Q' && str[0] != 'A' && str[0] != 'S')
                continue; // unknown command: ignore it
            if (scanf("%d%d", &a, &b) != 2)
                break; // truncated command arguments
            if (str[0] == 'Q')
                printf("%d\n", query(a, b, 1));
            else if (str[0] == 'A')
                update(a, a, b, 1);
            else
                update(a, a, -b, 1);
        }
    }
    return 0;
}