概念:线段树就是一棵二叉树,每个节点代表一个区间,主要用于解决区间类问题。每个节点的属性根据需要可以去自定义,比如节点的属性可以是区间和、区间最大/最小值。。
每个node有区间的左右端点,以及左右孩子
public class SegmentTreeNode{
int start,end,val; //val根据需要定义,比如我定义为区间最大值就是max
SegmentTreeNode left,right;
public SegmentTreeNode(int start, int end, int val){
this.start = start;
this.end = end;
this.val = val;
this.left = this.right = null;
}
}
自上而下,分治法,递归调用。
对于区间[m1,m2],mid = (m1+m2)/2,其左儿子区间是[m1,mid],右儿子区间是(mid+1,m2)
//线段树的构建,以求区间最大值为例,返回根节点
public SegmentTreeNode build(int start, int end, int[] A) {
if (start > end) {
return null;
}
if (start == end) {
return new SegmentTreeNode(start, end, A[start]);
}
//先new根区间,根区间最大值暂时为A[start],不能为其他乱七八糟的值比如-1这种...
SegmentTreeNode root = new SegmentTreeNode(start, end, A[start]);
if (start != end) {
int mid = (start + end) / 2;
root.left = build(start, mid, A);
root.right = build(mid + 1, end, A);
}
//改root的val
if (root.left != null && root.left.max > root.max) {
root.max = root.left.max;
}
if (root.right != null && root.right.max > root.max) {
root.max = root.right.max;
}
return root;
}
递归调用,一路向下找到最小区间,触底反弹的时候才去修改node。比如数组[6,3,5,1,9],我要修改1位置上的3,那就是一路向下先找到3,然后返回途中修改 ,时间复杂度logN
//线段树的修改,
public void modify(SegmentTreeNode root, int index, int value) {
if (root.start == index && root.end == index) {
root.max = value;
return;
}
int mid = (root.start + root.end) / 2;
//看index在左区间还是右区间
if (root.start <= index && index <= mid) {
modify(root.left, index, value);
}
if (mid < index && index <= root.end) {
modify(root.right, index, value);
}
//最后改下根
root.max = Math.max(root.left.max, root.right.max);
}
比如上面例子,找[0,3],先找[0,2],再找[3,3]。找[0,2],直接返回,找[3-3],就需要走到底
//线段树的查询
public int query(SegmentTreeNode root, int start, int end) {
if (start == root.start && end == root.end) {
return root.max;
}
int mid = (root.start + root.end) / 2;
int left_max = Integer.MIN_VALUE;
int right_max = Integer.MIN_VALUE;
//求左边最大值
//如果给定查询范围起点在左子树
if (start <= mid) {
//但是终点在右子树(横跨左子树和右子树),那么左边最大值就在start到mid之间查询
if (mid < end) {
left_max = query(root.left, start, mid);
} else { //如果只在左子树
left_max = query(root.left, start, end);
}
}
//求右边最大值
if (mid < end) {
//横跨左右子树的情况,起点为mid+1
if (start <= mid) {
right_max = query(root.right, mid + 1, end);
} else { //如果只在右子树
right_max = query(root.right, start, end);
}
}
return Math.max(left_max, right_max);
}
对于区间[m1,m2],mid = (m1+m2)/2,其左儿子区间是[m1,mid],右儿子区间是(mid+1,m2)
区间求和
给定一个整数数组(下标由 0 到 n-1,其中 n 表示数组的规模),以及一个查询列表。每一个查询列表有两个整数
[start, end]
。 对于每个查询,计算出数组中从下标 start 到 end 之间的数的总和输入: 数组 :[1,2,7,8,5], 查询:[(0,4),(1,2),(2,4)] 输出: [23,9,20]
思路
首先定义Node和树
//线段树Node
class SegmentTreeNode{
int start;
int end;
long sum;
SegmentTreeNode left, right;
public SegmentTreeNode(int start, int end) {
this.start = start;
this.end = end;
sum = 0;
left = right = null;
}
}
class SegmentTree{
private int size; //区间
private SegmentTreeNode root;
public SegmentTree(int[] A) {
size = A.length;
root = buildTree(A,0, size - 1);
}
private SegmentTreeNode buildTree(int[] A, int start, int end) {
SegmentTreeNode node = new SegmentTreeNode(start, end);
//递归出口,叶子节点
if (start == end) {
node.sum = A[start];
return node;
}
//不是出口,递归建立左子树和右子树
int mid = (start + end) / 2;
node.left = buildTree(A, start, mid);
node.right = buildTree(A, mid + 1, end);
//别忘记维护当前节点的sum
node.sum = node.left.sum + node.right.sum;
return node;
}
//查询对外界接口
public long querySum(int start, int end) {
return querySum(root, start, end);
}
//重载方法.在node节点下查询原数组start到end区间内的和
private long querySum(SegmentTreeNode node, int start, int end) {
//递归出口
if (node.start == start && node.end == end) {
return node.sum;
}
int mid = (node.start + node.end) / 2;
long leftSum = 0, rightSum = 0;
//左边区间
if (start <= mid) {
//如果不是跨区间
if (end <= mid) {
leftSum = querySum(node.left, start, end);
} else {
leftSum = querySum(node.left, start, mid);
}
//可以合并为一行 !!!!
// leftSum = querySum(node.left, start, Math.min(mid, end));
}
//要考虑右半区间,也就是start-end与右半区间有交集
if (end >= mid + 1) {
//如果不是跨区间
if (start >= mid + 1) {
rightSum = querySum(node.right, start, end);
} else {
rightSum = querySum(node.right, mid + 1, end);
}
// 可以合并为一句
// rightSum = querySum(node.right, Math.max(mid + 1, start), end);
}
return leftSum + rightSum;
}
}
实现方法
public class Interval {
int start, end;
Interval(int start, int end) {
this.start = start;
this.end = end;
}
}
public List<Long> intervalSum(int[] A, List<Interval> queries) {
List<Long> res = new ArrayList<>();
SegmentTree segmentTree = new SegmentTree(A);
for (Interval query : queries) {
long sum = segmentTree.querySum(query.start, query.end);
res.add(sum);
}
return res;
}
在类的构造函数中给一个整数数组, 实现两个方法
query(start, end)
和modify(index, value)
:
- 对于 query(start, end), 返回数组中下标 start 到 end 的 和。
- 对于 modify(index, value), 修改数组中下标为 index 上的数为 value.
比206多了modify,无法使用前缀和数组,暴力O(nm),线段树/树状数组O(mlogn),n为数组长度,m为操作次数。
线段树类中提供三个方法
public class Solution {
private SegmentTree segmentTree;
public Solution(int[] A) {
if (A == null || A.length == 0) {
return;
}
segmentTree = new SegmentTree(A);
}
public long query(int start, int end) {
return segmentTree.querySum(start, end);
}
public void modify(int index, int value) {
segmentTree.modify(index, value);
return;
}
}
class SegmentTreeNode{
public int start, end;
public long sum;
public SegmentTreeNode left, right;
public SegmentTreeNode(int start, int end) {
this.start = start;
this.end = end;
sum = 0;
left = right = null;
}
}
class SegmentTree{
public SegmentTreeNode root;
public int size;
public SegmentTree(int[] A) {
size = A.length;
root = buildTree(A, 0, size - 1);
}
private SegmentTreeNode buildTree(int[] A, int start, int end) {
SegmentTreeNode node = new SegmentTreeNode(start, end);
//递归出口
if (start == end) {
node.sum = A[start];
return node;
}
//不是出口则递归建所有子树
int mid = (start + end) / 2;
node.left = buildTree(A, start, mid);
node.right = buildTree(A, mid + 1, end);
node.sum = node.left.sum + node.right.sum;
return node;
}
//公开接口
public long querySum(int start, int end) {
return querySum(root, start, end);
}
//公开接口
public void modify(int index, int val) {
modify(root, index, val);
}
private long querySum(SegmentTreeNode node, int start, int end) {
if (node.start == start && node.end == end) {
return node.sum;
}
int mid = (node.start + node.end) / 2; //这边不是start和end 是node的区间
long leftSum = 0, rightSum = 0;
if (start <= mid) {
leftSum = querySum(node.left, start, Math.min(end, mid));
}
if (end >= mid + 1) {
rightSum = querySum(node.right, Math.max(start, mid + 1), end);
}
return leftSum + rightSum;
}
private void modify(SegmentTreeNode node, int index, int val) {
//递归出口:到达这个叶子节点,并修改它的值
if (node.start == node.end && node.end == index) {
node.sum = val;
return;
}
//递归:分为在左子树和右子树两种情况,不用求mid
if (node.left.end >= index) {
modify(node.left, index, val);
} else {
modify(node.right, index, val);
}
//最后改下根
node.sum = node.left.sum + node.right.sum;
}
}
统计比给定整数小的数的个数
给定一个整数数组 (下标由 0 到 n-1,其中 n 表示数组的规模,数值范围由 0 到 10000),以及一个查询列表。对于每一个查询,将会给你一个整数,请你返回该数组中小于给定整数的元素的数量。
输入: array =[1,2,7,8,5] queries =[1,8,5] 输出:[0,4,2]
时间复杂度:
线段树思路:
public class Solution {
public List<Integer> countOfSmallerNumber(int[] A, int[] queries) {
int[] B = new int[10001];
for (int i : A) {
B[i]++;
}
//建立线段树,大小为10001
SegmentTree tree = new SegmentTree(10001);
for (int i = 0; i < 10001; i++) {
tree.modify(i, B[i]); //i位置修改为B[i],表示这个数出现了多少次
}
List<Integer> res = new ArrayList<>();
for (int i : queries) {
if (i == 0) {
res.add(0); //没有数比0小,都是正数
} else {
res.add(tree.querySum(0, i - 1));
}
}
return res;
}
}
class SegmentTreeNode{
public int sum;
public int start, end;
public SegmentTreeNode left, right;
public SegmentTreeNode(int start, int end) {
this.start = start;
this.end = end;
sum = 0;
left = right = null;
}
}
class SegmentTree{
private int size;
private SegmentTreeNode root;
public SegmentTree(int size) {
this.size = size;
root = buildTree(0, size - 1);
}
//初始化得到的是全0的树
private SegmentTreeNode buildTree(int start, int end) {
SegmentTreeNode node = new SegmentTreeNode(start, end);
if (start == end) {
return node;
}
int mid = (start + end) / 2;
node.left = buildTree(start, mid);
node.right = buildTree(mid + 1, end);
return node;
}
public int querySum(int start, int end) {
return querySum(root, start, end);
}
//在node节点的子树下,查询[start,end]区间内维护的和
public int querySum(SegmentTreeNode node, int start, int end) {
if (node.start == start && node.end == end) {
return node.sum;
}
int leftSum = 0, rightSum = 0;
int mid = (node.start + node.end) / 2;
if (start <= mid) {
leftSum = querySum(node.left, start, Math.min(end, mid));
}
if (end >= mid + 1) {
rightSum = querySum(node.right, Math.max(start, mid + 1), end);
}
return leftSum + rightSum;
}
public void modify(int index, int val) {
modify(root, index, val);
}
private void modify(SegmentTreeNode node, int index, int val) {
if (node.start == node.end && node.end == index) { //可以省略node.end == index
node.sum = val;
return;
}
if (node.left.end >= index) {
modify(node.left, index, val);
} else {
modify(node.right, index, val);
}
//维护当前节点sum
node.sum = node.left.sum + node.right.sum;
}
}
统计前面比自己小的数的个数
给定一个整数数组(下标由 0 到 n-1, n 表示数组的规模,取值范围由 0 到10000)。对于数组中的每个
ai
元素,请计算ai
前的数中比它小的元素的数量。输入: [1,2,7,8,5] 输出: [0,1,2,3,2]
时间复杂度:
思路:
A=[1,2,7,8,5]
B=[0,0,0,0,0,0,0,0,0] 初始,这里B得开9,因为0~8一共9位
B=[0,1,0,0,0,0,0,0,0] B[1]++,统计比A中第二个元素(2)小的个数,B[0]+B[1]
B=[0,1,1,0,0,0,0,0,0] B[2]++,A中第三个元素为7,计算B[0]+…+B[6]
B=[0,1,1,0,0,0,0,1,1] B[7]++,A中第四个元素为8,计算B[0]+…+B[7]
B=[0,1,1,0,0,1,0,1,1] B[8]++,A中第五个元素为5,计算B[0]+…+B[4]
public class Solution {
public List<Integer> countOfSmallerNumberII(int[] A) {
List<Integer> res = new ArrayList<>();
SegmentTree tree = new SegmentTree(10001);
int[] B = new int[10001];
for (int i : A) {
if (i == 0) {
res.add(0);
} else {
res.add(tree.querySum(0, i - 1));
}
//更新B
B[i]++;
tree.modify(i, B[i]);
}
return res;
}
}
class SegmentTreeNode{
public int sum;
public int start, end;
public SegmentTreeNode left, right;
public SegmentTreeNode(int start, int end) {
this.start = start;
this.end = end;
sum = 0;
left = right = null;
}
}
class SegmentTree{
private int size;
private SegmentTreeNode root;
public SegmentTree(int size) {
this.size = size;
root = buildTree(0, size - 1);
}
//初始化得到的是全0的树
private SegmentTreeNode buildTree(int start, int end) {
SegmentTreeNode node = new SegmentTreeNode(start, end);
if (start == end) {
return node;
}
int mid = (start + end) / 2;
node.left = buildTree(start, mid);
node.right = buildTree(mid + 1, end);
return node;
}
public int querySum(int start, int end) {
return querySum(root, start, end);
}
//在node节点的子树下,查询[start,end]区间内维护的和
public int querySum(SegmentTreeNode node, int start, int end) {
if (node.start == start && node.end == end) {
return node.sum;
}
int leftSum = 0, rightSum = 0;
int mid = (node.start + node.end) / 2;
if (start <= mid) {
leftSum = querySum(node.left, start, Math.min(end, mid));
}
if (end >= mid + 1) {
rightSum = querySum(node.right, Math.max(start, mid + 1), end);
}
return leftSum + rightSum;
}
public void modify(int index, int val) {
modify(root, index, val);
}
private void modify(SegmentTreeNode node, int index, int val) {
if (node.start == node.end && node.end == index) { //可以省略node.end == index
node.sum = val;
return;
}
if (node.left.end >= index) {
modify(node.left, index, val);
} else {
modify(node.right, index, val);
}
//维护当前节点sum
node.sum = node.left.sum + node.right.sum;
}
}