上周的校赛打的好差好差啊,我都想放弃了。
不过既然实验室有各种竞赛大神讲算法,秉着学习的目的,我还是没退。。
为了不让才写了一篇的博客就这样断更。
所以我还是接着写下去吧。。
So,正题来了,线段树
刚看到线段树的结构时,我第一个想到的就是树状数组。因为它们不仅结构相似,算法思想差不多,而且实现的功能接近。
大概是这样一个东西,每个节点都代表一个区间,比如1节点代表2,3节点的区间和....而每个节点的左子节点和右子节点将其区间平分为两份。
类似树状数组,我们可以通过这个结构进行区间求和,更新,查询等方法;
我们可以观察出一个规律,对于x节点代表的区间[left , right],它的左子节点的区间为[left , mid],右子节点的区间为[mid+1 , right],
左子节点的标号为 x*2 右子节点的编号为x*2+1 ;
根据这个规律,我们可以很快写出一个建树方法
class node{
int left=0,right=0,s=0;
public node(int le,int ri){
this.left=le;
this.right=ri;
}
}
这是节点类
static int b[];//b[]存储的是需要建树的数组;
static node []tree=new node[200086];
static int build(int left,int right,int x){
tree[x]=new node(left,right);
if(left==right)return tree[x].s=b[left];//当递归到最底层时,也就是left==right时
else{
int mid =(left+right)>>1;
tree[x].s+=build(left,mid,x*2); //左子节点
tree[x].s+= build(mid+1,right,x*2+1);//右子节点
return tree[x].s;
}
}
建树方法
static void update(int x, int pos,int n){
if(tree[x].left==tree[x].right) tree[x].s+=n;//在递归前保证了pos在区间内,left==right时说明该点为pos点
else{
int mid =(tree[x].right+tree[x].left)>>1;
if(pos<=mid)update(2*x,pos,n); //若pos在左子节点的区间中
else update(2*x+1,pos,n); //若pos在右子节点的区间中
tree[x].s=tree[2*x].s+tree[2*x+1].s;
}
}
static int query(int x,int le,int ri){
if(tree[x].left==le&&tree[x].right==ri)return tree[x].s;
else{
int mid =(tree[x].right+tree[x].left)>>1;
if(ri<=mid)return query(2*x,le,ri);
else if(le>mid) return query(2*x+1,le,ri);
else return query(2*x,le,mid)+query(2*x+1,mid+1,ri);
}
}
区间查询
最基本的操作就这样实现了。
其中建树的时间复杂度为O(log N)每次查询操作最多也是O(log N);
有一个小细节是树状数组是从下到上更新,而线段树是从上到下更新
这就是很裸的线段树实现单点更新,区间查询的题目
树状数组做会非常简单,而且更快
AC代码如下
package SegmentTree;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.*;
class node{
int left=0,right=0,s=0;
public node(int le,int ri){
this.left=le;
this.right=ri;
}
}
public class Main {
static BufferedReader in=new BufferedReader(new InputStreamReader(System.in));
static StringTokenizer tok;
static boolean hasNext()
{
while(tok==null||!tok.hasMoreTokens())
try{
tok=new StringTokenizer(in.readLine());
}
catch(Exception e){
return false;
}
return true;
}
static String next()
{
hasNext();
return tok.nextToken();
}
static long nextLong()
{
return Long.parseLong(next());
}
static int nextInt()
{
return Integer.parseInt(next());
}
static PrintWriter out=new PrintWriter(new OutputStreamWriter(System.out));
static int b[];
static node []tree=new node[200086];
static int build(int left,int right,int x){
tree[x]=new node(left,right);
if(left==right)return tree[x].s=b[left];
else{
int mid =(left+right)>>1;
tree[x].s+=build(left,mid,x*2);
tree[x].s+= build(mid+1,right,x*2+1);
return tree[x].s;
}
}
static void update(int x, int pos,int n){
if(tree[x].left==tree[x].right) tree[x].s+=n;
else{
int mid =(tree[x].right+tree[x].left)>>1;
if(pos<=mid)update(2*x,pos,n);
else update(2*x+1,pos,n);
tree[x].s=tree[2*x].s+tree[2*x+1].s;
}
}
static int query(int x,int le,int ri){
if(tree[x].left==le&&tree[x].right==ri)return tree[x].s;
else{
int mid =(tree[x].right+tree[x].left)>>1;
if(ri<=mid)return query(2*x,le,ri);
else if(le>mid) return query(2*x+1,le,ri);
else return query(2*x,le,mid)+query(2*x+1,mid+1,ri);
}
}
public static void main(String[] args) {
// TODO Auto-generated method stub
//Scanner sc =new Scanner (System.in);
int T=nextInt();
for(int t=1;t<=T;t++){
int N=nextInt();
b=new int [N+1];
for(int i=1;i<=N;i++){
b[i]=nextInt();
}
build(1,N,1);
out.println("Case "+t+":");
out.flush();
while(true){
String str=next();
if(str.equals("End"))break;
if(str.equals("Query")){
int left=nextInt();
int right=nextInt();
out.println(query(1,left,right));
out.flush();
}
else if(str.equals("Add")){
int pos=nextInt();
int n=nextInt();
update(1,pos,n);
}
else if(str.equals("Sub")){
int pos=nextInt();
int n=nextInt();
update(1,pos,-n);
}
}
}
}
}
这题则是实现区间最大值查询和更新
JAVA莫名爆出MLE,改用C++写了一遍
代码如下
#include
#include
#include
#include
#include
#include
using namespace std;
const int maxn=200010;
int b[maxn];
struct node{
int left,right,max;
}tree[maxn*4];
int n,m;
static int build(int left,int right,int x){
tree[x].left=left;
tree[x].right=right;
if(left==right)return tree[x].max=b[left];
else{
int mid =(left+right)>>1;
return tree[x].max=max(build(left,mid,x*2), build(mid+1,right,x*2+1));
}
}
static int update(int x, int pos,int n){
if(tree[x].left==tree[x].right) return tree[x].max=n;
else{
int mid =(tree[x].right+tree[x].left)>>1;
if(pos<=mid){
return tree[x].max=max(tree[x].max, update(2*x,pos,n));
}
else {
return tree[x].max=max(tree[x].max,update(2*x+1,pos,n));
}
}
}
static int query(int x,int le,int ri){
if(tree[x].left==le&&tree[x].right==ri)return tree[x].max;
else{
int mid =(tree[x].right+tree[x].left)>>1;
if(ri<=mid)return query(2*x,le,ri);
else if(le>mid) return query(2*x+1,le,ri);
else {
int a=query(2*x,le,mid);
int b=query(2*x+1,mid+1,ri);
return max(a,b);
}
}
}
int main() {
while(~scanf("%d%d",&n,&m)){
for(int i=1;i<=n;i++)scanf("%d",&b[i]);
build(1,n,1);
while(m--){
int a,b;char s[10];
scanf("%s%d%d",&s,&a,&b);
if(s[0]=='Q'){
printf("%d\n",query(1,a,b));
}
else if(s[0]=='U'){
update(1,a,b);
}
}
}
return 0;
}