题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4027
这道线段树的题目我并没有按照线段树的模板写,这里把要点讲出来就行了。
第一:区间开根号不像区间乘或者除,能通过区间和然后一次性乘除求出一个区间的乘除,所以区间开根号只能一个一个的开,但是如果一个一个的开不会超时吗? 假设一个数在long long的范围里面,也就是2^63这个范围内,那么对这个数开根号最多要开几次? 63次。也就是说,开根号能够在很少的次数内,让一个数一直趋近于一个固定值1,这是区间开根号的一个突破口,只要知道当前这个区间的值都等于或者小于1,那就不需要对这个区间做开根号的操作了。虽然一个一个的开根号看似很费时,但是事实上是可行的,我们假设有 10^5(题目N的上限)个数,每个数都是2^63,每次都对整个区间开根号我们需要花费多少时间,因为线段树是一个把数组二叉树化的结构,所以每次访问一个数,需要log(N)的时间复杂度,有N个数,并且每个数最多开log(2^63)次,那么时间复杂度大约是 15*63*10^5;而且题目有说明,所有数的总和是不超过 2^63的,所以实际的时间复杂度还比上面说的要小很多;
所以具体的代码要怎么写?
首先需要一个记录区间最大值的数组,只要当前区间的最大值是小于或者等于1,那么就不用继续访问下面的节点了,因为1开根号仍然是1,这能为我们省下不少时间,其次是一个区间和的数组。
每次区间修改,都需要遍历到最底部的叶子节点,除非中途遇到区间最大值是小于或者等于1的就不用向下遍历了,修改叶子节点的值,并且更新最大值,和记录开根号后减少的值,在回溯的过程中更新区间最大值和区间和。
第二:这里的数据有可能出现 X > Y 的情况,所以要记得判断,这是题目的一个坑点;
下面是代码,因为我没有用线段树的模板,所以重点在上面的叙述,最近在尝试着脱离模板做题。
#include
#include
#include
#include
#include
using namespace std;
typedef long long ll;
const int Maxn = 1e5+10;
const int INF = 0x3f3f3f3f;
ll a[Maxn], maxn[4*Maxn], sum[4*Maxn]; // 把区间和和区间最大值设为全局,在更新区间和和区间最值得时候
int L, R; // 通过访问下标就能很容易的更新数组
// 2*cur, 2*cur+1, cur/2,左右节点和父节点
ll init_max(int cur, int l, int r) {
if(l == r) {
maxn[cur] = a[l]; return a[l];
}
int mid = (l+r)/2;
maxn[cur] = max(init_max(cur*2, l, mid), init_max(cur*2+1, mid+1, r));
return maxn[cur];
}
ll init_sum(int cur, int l, int r) {
if(l == r) {
sum[cur] = a[l]; return a[l];
}
int mid = (l+r)/2;
sum[cur] = init_sum(cur*2, l, mid)+init_sum(cur*2+1, mid+1, r);
return sum[cur];
}
void updata(int cur, int l, int r) { // 更新区间和的方法是记录当前节点修改后与修改前的差值,
if (maxn[cur] <= 1) return; // 通过 cur/2 直接修改父节点的值。
ll tmp;
if (l == r) {
tmp = a[l]-(ll)sqrt(a[l]);
sum[cur/2] -= tmp; a[l] -= tmp;
sum[cur] = a[l]; maxn[cur] = a[l];
return;
}
tmp = sum[cur];
int mid = (r+l)/2;
if (L <= mid) updata(cur*2, l, mid);
if (mid+1 <= R) updata(cur*2+1, mid+1, r);
maxn[cur] = max(maxn[cur*2], maxn[cur*2+1]); // 这里更新区间最大值和区间和可以直接通过数组下标得到
sum[cur/2] -= (tmp-sum[cur]);
}
ll solve(int cur ,int l, int r) {
if (L <= l && r <= R) return sum[cur];
int mid = (r+l)/2;
ll ret = 0;
if (L <= mid) ret += solve(cur*2, l, mid);
if (mid+1 <= R) ret += solve(cur*2+1, mid+1, r);
return ret;
}
int main (void)
{
int N, cas = 0;
while (scanf ("%d", &N) != EOF) {
for (int i = 1; i <= N; ++i) scanf ("%lld", &a[i]);
init_max(1, 1, N); init_sum(1, 1, N); // 预先把区间的最值和区间和求出
int M, T;
scanf ("%d", &M);
printf("Case #%d:\n", ++cas);
while (M--) {
scanf ("%d%d%d", &T, &L, &R);
if (L > R) swap(L, R); // HDU的坑点
if (T) {
printf("%lld\n", solve(1, 1, N));
} else {
updata(1, 1, N);
}
}
printf("\n");
}
return 0;
}