线段树具有良好的二分结构,支持logN的插入,查找,查询操作,是处理区间问题的利器。
近段时间我将会逐渐完成一位HDU的大牛的线段树专辑的部分题目与其它线段树题目以巩固线段树的学习,并把代码发上来。
因为HDU题库交Pascal的程序总是CE,所以大部分是用C语言写的。
/* 题意:给出n个兵营的人数,给出若干命令,包括增减某个兵营的人数和询问编号在[x..y]区间内的兵营的总人数有多少 算法:更新节点,区间求和,虽然可以用树状数组过,但也可练练线段树 体会: 线段树各过程中变量不要敲错! */ #include <stdio.h> #include <string.h> #define root 1 #define MAXN 200000 struct node { int left,right,num; }t[MAXN]; int a[MAXN]; int build(int p, int left, int right) { t[p].left = left; t[p].right = right; if (left == right) return t[p].num = a[left]; int m = (left + right) / 2; return t[p].num = build(p*2, left, m) + build(p*2+1, m+1, right); } int query(int p, int left ,int right) { int m = (t[p].left + t[p].right) / 2; if ((t[p].left == left) && (t[p].right == right)) return t[p].num; else if (right <= m) return query(p*2, left, right); else if (m < left) return query(p*2+1,left, right); else return query(p*2, left, m)+query(p*2+1, m+1, right); } void update(int p,int x,int y) { t[p].num += y; if (t[p].left == t[p].right) return; else { int m = (t[p].left + t[p].right) / 2; if (x <= m) update(p*2,x,y); else update(p*2+1,x,y); } } int main() { int t,i,j,k,m,n,x,y; scanf("%d", &t); for (k = 1; k <= t; k++) { printf("Case %d:\n",k); scanf("%d", &n); for (i = 1; i <= n; i++) scanf("%d", &a[i]); build(root,1,n); char cmd[10]; while (scanf("%s",cmd)) { if ((strcmp(cmd, "End") == 0)) break; scanf("%d%d", &x, &y); switch (cmd[0]) { case 'Q': printf("%d\n",query(root,x,y)); break; case 'A': update(root,x,y); break; case 'S': update(root,x,-y); break; } } } return 0; }
/* 题意:给出n个学生的成绩,老师会有m条指令:查询区间内的最高成绩和修改某个同学的成绩。 算法:线段树,节点更新,区间最值 体会:没说,类似的PASCAL写过,tyvj P1039忠诚2,几乎一样。不过这次用C独立完成,20分钟左右敲完了,还算可以。另外因为习惯一个函数名打错了,于是就define了一下 */ #include <stdio.h> #define MAXN 1000000 #define root 1 #define query getmax struct node { int left,right,max; }t[MAXN]; int a[MAXN]; int max(int a,int b) { return a>b ? a : b; } int build(int p,int left,int right) { t[p].left = left; t[p].right = right; if (left == right) return t[p].max = a[left]; int m = (left + right) / 2; return t[p].max = max(build(p*2, left, m), build(p*2+1, m+1,right)); } int getmax(int p,int left, int right) { if (t[p].left == left && t[p].right == right) return t[p].max; int m = (t[p].left + t[p].right) / 2; if (right <= m) return (query(p*2, left, right)); if (left > m) return (query(p*2+1, left, right)); return max(query(p*2,left,m),query(p*2+1,m+1,right)); } void change(int p, int goal, int grade) { if (t[p].left == t[p].right) t[p].max = grade; else { int m = (t[p].left + t[p].right) / 2; if (goal <= m) { change(p*2, goal, grade); t[p].max = max(t[p].max,t[p*2].max); } else { change(p*2+1, goal, grade); t[p].max = max(t[p].max,t[p*2+1].max); } } } int main() { int i,j,k,m,n,x,y; while (scanf("%d%d", &n, &m) == 2) { for (i = 1; i <= n; i++) scanf("%d", &a[i]); build(root,1,n); for (i = 1; i <= m; i++) { char cmd[10]; scanf("%s",cmd); scanf("%d%d", &x, &y); if (cmd[0] == 'Q') printf("%d\n",getmax(root,x,y)); else change(root, x, y); } } return 0; }
/* 题意:一个由0..n-1组成的序列,每次可以把队首的元素移到队尾, 求形成的n个序列最小逆序对数目 算法:将元素依次插入线段树,每次增加的逆序对数为比它大的已经插入的 数的个数,可以用线段树维护,由于元素值为0..n,每次移动可求出增减 逆序对的数量更新。 体会:线段树敲了20分钟,如果是PASCAL并且写得难看些10分钟应该能搞定。 Debug 10分钟,主要是更新时没考虑清楚+记错题目求成最大逆序对数了 */ #include <stdio.h> #define MAXN 100000 #define ROOT 1 struct node { int left,right,sum; }t[MAXN]; int val[MAXN]; int n; void build(int p, int left, int right) { int m; t[p].left = left; t[p].right = right; t[p].sum = 0; if (left == right) return; m = (left + right) / 2; build(p*2, left, m); build(p*2+1, m+1, right); } void update(int p, int goal, int add) { t[p].sum += add; if (t[p].left == t[p].right) return; int m = (t[p].left + t[p].right) / 2; if (goal <= m) update(p*2, goal, add); if (goal > m) update(p*2+1, goal, add); } int getsum(int p, int left, int right) { if (left > right) return 0; if (t[p].left == left && t[p].right == right) return t[p].sum; int m = (t[p].left + t[p].right) / 2; if (right <= m) return getsum(p*2, left, right); else if (left > m) return getsum(p*2+1, left, right); else return getsum(p*2, left, m) + getsum(p*2+1, m + 1, right); } int main() { while (scanf("%d", &n) == 1) { build(ROOT, 0, n - 1); int i,sum = 0,ans; for (i = 1; i <= n; i++) { scanf("%d", &val[i]); sum += getsum(ROOT, val[i], n - 1); update(ROOT, val[i], 1); } ans = sum; for (i = 1; i <= n; i++) { sum = sum + (n - val[i] - 1) - val[i]; ans = sum < ans ? sum : ans; } printf("%d\n", ans); } return 0; }