这篇文章来讲二分法,这是一种在实际情况中十分常用的算法
我们之前讲过,解决计算机问题的一个常规方案就是暴力搜索,即:遍历整个搜索空间,找到给定问题的解
在这个基础上,针对问题的不同特征,我们可以应用不同的数据结构和算法,去优化搜索的时间和空间效率
二分搜索算法就是针对有序区间的元素搜索问题进行的时间效率优化
换句话说,区间有序是应用二分搜索算法的必要条件
如果要求在乱序数组中找到给定值,那么唯一的做法就是逐个遍历数组元素
如果要求在有序数组中找到给定值,那么就可以使用二分搜索算法进行处理
二分搜索算法的思路很简单:通过比较区间中值和给定值,每次可以缩小一半搜索区间
举个例子,给定有序数组 [1, 2, 3, 4, 5, 6, 7, 8, 9]
,要求找出元素 6
,算法步骤如下:
5
和给定值 6
,发现 5 < 6
,所以给定值必在区间右半部分 [6, 7, 8, 9]
7
和给定值 6
,发现 7 > 6
,所以给定值比在区间左半部分 [6]
6
和给定值 6
,发现 6 = 6
,至此就能找出给定值二分搜索算法的思路很简单,但实现起来需要注意的细节却有很多
下面先按上述所说思路给出一个代码模板,该模板用于在升序数组中查找给定值
如果能够找到,则返回对应下标;如果没有找到,则返回 -1
int binary_search(vector<int>& nums, int target) {
int n = nums.size();
// 定义搜索区间为 [p1 , p2]
int p1 = 0;
int p2 = n - 1;
// 定义终止条件为 (p1 > p2)
while (p1 <= p2) {
// 计算中值:
// 一般计算中值的方式是 (p1 + p2) / 2
// 为了防止相加溢出改成 ⬇
int mid = p1 + (p2 - p1) / 2;
// 分类讨论:
// 若中值等于给定值,则表示已找到,返回结果就好
// 若中值小于给定值,则将搜索区间约束在右半部分
// 若中值大于给定值,则将搜索区间约束在左半部分
if (nums[mid] == target) {
// 中值等于给定值
// 返回结果
return mid;
} else if (nums[mid] < target) {
// 中值小于给定值
// 更新搜索区间为右半部分,此时左边界更新,右边界不变
p1 = mid + 1;
} else if (nums[mid] > target) {
// 中值大于给定值
// 更新搜索区间为左半部分,此时右边界更新,左边界不变
p2 = mid - 1;
}
}
// 没有找到,返回结果
return -1;
}
上述的代码思路很清晰,但实际写起来可能会有很多细节值得注意
常见的问题集中在:搜索区间的定义、终止条件的定义、搜索区间的更新
搜索区间的定义,第 4
、5
行
这里定义的搜索区间是左闭右闭区间 [p1, p2]
,初始化为 p1 = 0, p2 = n - 1
当然也有其他人定义为左闭右开区间 [p1, p2)
,初始化为 p1 = 0, p2 = n
这两种定义方式都是没有问题的,区别在于后续要怎么定义终止条件和更新搜索区间
终止条件的定义,第 7
行
当搜索区间为空时,就可以终止搜索,具体来说:
对于 [p1, p2]
,区间为空时有 p1 > p2
,也即 while
运行条件应为 p1 <= p2
对于 [p1, p2)
,区间为空时有 p1 >= p2
,也即 while
运行条件应为 p1 < p2
搜索区间的更新,第 23
、27
行
判断完中值和给定值的大小关系后,新的搜索区间应该去掉中值分为左右两部分,具体来说:
对于 [p1, p2]
,新的搜索区间应为 [p1, mid - 1]
和 [mid + 1, p2]
对于 [p1, p2)
,新的搜索区间应为 [p1, mid)
和 [mid + 1, p2)
下面针对三种常见的二分搜索场景给出代码模板,以后遇到相似的场景时可以举一反三地解决问题
在升序数组中查找给定值唯一出现的位置,降序数组思路类似,可以自行推理
如果能够找到,则返回对应下标;如果没有找到,则返回 -1
int binary_search(vector<int>& nums, int target) {
int n = nums.size();
int p1 = 0;
int p2 = n - 1;
while (p1 <= p2) {
int mid = p1 + (p2 - p1) / 2;
if (nums[mid] == target) {
// 若中值等于给定值,则表示已找到,返回结果就好
return mid;
} else if (nums[mid] < target) {
// 若中值小于给定值,则将搜索区间约束在右半部分
p1 = mid + 1;
} else if (nums[mid] > target) {
// 若中值大于给定值,则将搜索区间约束在左半部分
p2 = mid - 1;
}
}
return -1;
}
在升序数组中查找给定值最先出现的位置,降序数组思路类似,可以自行推理
如果能够找到,则返回对应下标;如果没有找到,则返回 -1
int lower_bound(vector<int>& nums, int target) {
int n = nums.size();
int p1 = 0;
int p2 = n - 1;
while (p1 <= p2) {
int mid = p1 + (p2 - p1) / 2;
if (nums[mid] == target) {
// 不同之处
// 若中值等于给定值,则将搜索范围约束在左半部分继续搜索
p2 = mid - 1;
} else if (nums[mid] < target) {
// 若中值小于给定值,则将搜索区间约束在右半部分
p1 = mid + 1;
} else if (nums[mid] > target) {
// 若中值大于给定值,则将搜索区间约束在左半部分
p2 = mid - 1;
}
}
// 不同之处
// 最后检查 p1 是否符合条件
if (p1 > n - 1 || nums[p1] != target) {
return -1;
}
return p1;
}
在升序数组中查找给定值最后出现的位置,降序数组思路类似,可以自行推理
如果能够找到,则返回对应下标;如果没有找到,则返回 -1
int upper_bound(vector<int>& nums, int target) {
int n = nums.size();
int p1 = 0;
int p2 = n - 1;
while (p1 <= p2) {
int mid = p1 + (p2 - p1) / 2;
if (nums[mid] == target) {
// 不同之处
// 若中值等于给定值,则将搜索范围约束在右半部分继续搜索
p1 = mid + 1;
} else if (nums[mid] < target) {
// 若中值小于给定值,则将搜索区间约束在右半部分
p1 = mid + 1;
} else if (nums[mid] > target) {
// 若中值大于给定值,则将搜索区间约束在左半部分
p2 = mid - 1;
}
}
// 不同之处
// 最后检查 p2 是否符合条件
if (p2 < 0 || nums[p2] != target) {
return -1;
}
return p2;
}
(1)在有序数组中查找给定值可插入位置 | leetcode35
给定一个已排序数组和一个目标值,不考虑重复元素
如果目标值在数组内,返回其索引,如果目标值不在数组内,返回其按顺序插入的位置
解题思路,可转换为在有序数组中查找第一个大于等于给定值的位置
class Solution {
public:
int searchInsert(vector<int>& nums, int target) {
int n = nums.size();
int p1 = 0;
int p2 = n - 1;
while (p1 <= p2) {
int mid = p1 + (p2 - p1) / 2;
if (nums[mid] == target) {
p2 = mid - 1;
} else if (nums[mid] < target) {
p1 = mid + 1;
} else if (nums[mid] > target) {
p2 = mid - 1;
}
}
return p1;
}
};
(2)在有序数组中查找给定值出现的范围 | leetcode34
给定一个已排序数组和一个目标值,数组中有重复的元素
找出目标值在数组中的开始位置和结束位置,如果目标值不在数组内,则返回 [-1, -1]
解题思路,可转换为在有序数组中查找给定值最先出现的位置和在有序数组中查找给定值最后出现的位置
class Solution {
public:
vector<int> searchRange(vector<int>& nums, int target) {
int p1 = lower_bound(nums, target);
int p2 = upper_bound(nums, target);
return (p1 == -1 || p2 == -1) ? vector<int>{-1, -1} : vector<int>{p1, p2};
}
int lower_bound(vector<int>& nums, int target) {
int n = nums.size();
int p1 = 0;
int p2 = n - 1;
while (p1 <= p2) {
int mid = p1 + (p2 - p1) / 2;
if (nums[mid] == target) {
p2 = mid - 1;
} else if (nums[mid] < target) {
p1 = mid + 1;
} else if (nums[mid] > target) {
p2 = mid - 1;
}
}
if (p1 > n - 1 || nums[p1] != target) {
return -1;
}
return p1;
}
int upper_bound(vector<int>& nums, int target) {
int n = nums.size();
int p1 = 0;
int p2 = n - 1;
while (p1 <= p2) {
int mid = p1 + (p2 - p1) / 2;
if (nums[mid] == target) {
p1 = mid + 1;
} else if (nums[mid] < target) {
p1 = mid + 1;
} else if (nums[mid] > target) {
p2 = mid - 1;
}
}
if (p2 < 0 || nums[p2] != target) {
return -1;
}
return p2;
}
};
(3)吃香蕉 | leetcode875
给定一个数组,数组中的元素 piles[i]
表示第 i
堆香蕉的数量,单位为根
返回能在给定时间 h
内吃完所有香蕉的最小速度 k
,其中 h
的单位为时,k
的单位为根/时
在每小时中,只能选择一堆香蕉,如果这堆香蕉小于 k
根,吃完后这小时也不能吃别的香蕉
解题思路
- 吃香蕉的速度和能否在给定时间内吃完所有香蕉存在单调性,可以考虑用二分查找搜索最小速度
- 另一个问题是给定一个速度,怎么计算需要多少时间才能吃完所有香蕉
class Solution {
public:
int minEatingSpeed(vector<int>& piles, int h) {
// 左边界为最小速度,显然为一
// 右边界为最大速度,因为每小时最多只能吃完一堆香蕉,所以取每堆香蕉中的最大根数即可
int p1 = 1;
int p2 = 0;
for (int pile: piles) {
p2 = max(p2, pile);
}
// 二分查找
while (p1 <= p2) {
int mid = p1 + (p2 - p1) / 2;
long long int tmp = getTime(piles, mid);
if (tmp == h) {
p2 = mid - 1;
} else if (tmp < h) {
p2 = mid - 1;
} else if (tmp > h) {
p1 = mid + 1;
}
}
return p1;
}
// 给定一个速度
// 计算需要多少时间才能吃完所有香蕉
long long int getTime(vector<int>& piles, int speed) {
long long int time = 0;
for (int pile: piles) {
// (a + b - 1) / b 相当于 a / b 向上取整
time += (pile + speed - 1) / speed;
}
return time;
}
};
(6)运包裹 | leetcode1011
给定一个数组,数组中的元素 weights[i]
表示第 i
个包裹的重量,单位为吨
返回能在给定时间 d
内运走所有包裹的最小运力 c
,其中 d
的单位为天,c
的单位为吨/天
在每一天中,需要按数组顺序运走若干个包裹,要求运走包裹的重量不能超过 c
吨
解题思路
- 运包裹的运力和能否在给定时间内运走所有货物存在单调性,可以考虑用二分查找搜索最小运力
- 另一个问题是给定一个运力,怎么计算需要多少时间才能运走所有包裹
class Solution {
public:
int shipWithinDays(vector<int>& weights, int d) {
// 左边界为最小运力,至少要等于每个包裹中的最大重量
// 右边界为最大运力,因为可以一次运走所有包裹,所以取所有包裹重量的总和
int p1 = 0;
int p2 = 0;
for (int weight: weights) {
p1 = max(p1, weight);
p2 = p2 + weight;
}
// 二分查找
while (p1 <= p2) {
int mid = p1 + (p2 - p1) / 2;
int tmp = getTime(weights, mid);
if (tmp == d) {
p2 = mid - 1;
} else if (tmp < d) {
p2 = mid - 1;
} else if (tmp > d) {
p1 = mid + 1;
}
}
return p1;
}
// 给定一个运力
// 计算需要多少时间才能运走所有包裹
int getTime(vector<int>& weights, int capacity) {
int curr = 0;
int need = 1;
for (int weight: weights) {
if (curr + weight > capacity) {
curr = 0;
need = need + 1;
}
curr = curr + weight;
}
return need;
}
};