我们有N个数字,多次区间进行修改和区间求和,每次求和时输出结果。
两种思路,第一种是用线段树,每个节点维护 [L , R) 左闭右开的区间,保存2个值,1、datChild某个区间内子节点都加上的值;2、dat某个区间自己的值;
更新 线段树某个 i 节点的 [L_, R_) 区间,都加上v时的操作:
1)、如果 L_、R_与节点区间 [L , R)是不是没有交集,对 i 节点操作结束。
2)、如果 [L , R) 是不是被 [L_ , R_) 完全包含,datChild[i]+=v,然后从 i 开始(包括 i )到 i 的所有父亲节点,都加上 ( R - L ) * v
3)、如果是区间重合但不完全包含,那就再对 i 的两个子节点 i * 2 + 1 和 i * 2 + 2 进行 1)、2)、3)的递归操作,(两个子节点维护的区间分别是 [L , (L + R) / 2 )和[(L + R) / 2 , R)。)
查询线段树 [L_ , R_)的区间的和的操作
1)、如果 L_、R_与节点区间 [L , R)是不是没有交集,返回0。
2)、如果 [L , R) 是不是被 [L_ , R_) 完全包含,返回 dat[i]
3)、如果是区间重合但不完全包含,那就再对 i 的两个子节点 i * 2 + 1 和 i * 2 + 2 进行 1)、2)、3)的操作递归求出 leftSum,rightSum,之后再加上这段区间里来自父亲节点的值,最终返回的结果为
leftSum + rightSum + (min(R , R_) - max(L , L_)) * datChild[i]
(两个子节点维护的区间分别是 [L , (L + R) / 2 )和[(L + R) / 2 , R)。)
线段树的两个数组需要开 n 最接近的一个 2^i 再乘以 2 的长度,类型需要 long long。
线段树的代码其实是比较复杂,各种判断其实略带繁琐。
这里再引出树状数组Bit的操作。
建立两个树状数组,bit和bitChild,
bit[i] 用来代表某个节点维护的区间的sum,即先求出 i 的二进制的最后一个1记作j,[ i - j + 1 , i ]。
bitChild用来代表 [ i - j + 1 , i ] 区间内子节点都加上的值。
查询 [1 , R] 的sum操作(注意树状数组的下标从 1 开始不是0,且都是闭区间。
首先、设 i = R,返回结果为 sum
1)sum += bit[i],然后设 j 等于 i + (i & (-i)),之后sum += datChild[j] * (i & (-i),我们能够直到 i + ( i & ( -i)是 i 的 直系父节点,然后 i & (-i)是i节点维护区间的长度,所以这段区间内每个元素都加上的值再乘以区间的长度,就等于它本应该的值,同时需要让 j 不断循环,j = j + (j & (-j)),计算所有 j<=n的父节点。
2)然后 i = i - (i & (-i)),再重复第一步的过程,不断的循环直到计算全部的 i > 0(这是因为树状数组 bit 维护某个区间 r 的方式是 bit[r] + bit[r-(r&(-r)]+bit[r-(r&(-r))-((r-(r&(-r)))&((-1)*(r-(r&(-r)))))]+...
更新 [1,R] 的区间,都加上v的操作如下
首先,设 i + R
1)将 bitChild[i]+=v,然后计算区间长度与更新值的乘积,result=(i & (-i))*v,之后设j = i,不断的让 bit[j]+=result,之后循环j = j + (j&(-j)),计算完所有 j <=n的j。
2)然后 i = i - (i & (-i)),再重复第一步的过程,不断的循环直到计算全部的 i > 0(这是因为树状数组 bit 维护某个区间 r 的方式是 bit[r] + bit[r-(r&(-r)]+bit[r-(r&(-r))-((r-(r&(-r)))&((-1)*(r-(r&(-r)))))]+...
(备注下,i & (-i)就是i最右边那个1,它也代表着 i 节点 控制的范围是 [i-(i&(-i)+1,i] 比如
6
=0110(二进制)
=0100 + 0010(二进制)
=4 + 2 (十进制)
所以6节点的维护是 [ 6-2+1,6]区间,即[5,6],然后最右边那个1是0010也就是2
树状树状肯定比线段树难理解一点,毕竟线段树只是普通的完美二叉树
而树状数组能够更简单的实现并且只占用线段树二分之一的空间。且代码简洁干净,实现很快而且没有用到递归。
但是在这个区间修改问题中,树状数组的速度略慢于线段树(慢200ms)
#include
using namespace std;
typedef long long ll;
int num[100007], n, n_;
ll dat[262150], datChild[262150];
void input()
{
for (int i = 0; i < n_; i++)
{
scanf("%d", &num[i]);
}
}
void init()
{
n = 1;
while (n < n_)
{
n = n * 2;
}
for (int i = ((n * 2) - 2); i >= 0; i--)
{
dat[i] = 0LL;
datChild[i] = 0LL;
if (i >= (n - 1))
{
if ((i - (n - 1)) < n_)
{
dat[i] = ((ll)num[i - (n - 1)]);
}
}
else
{
dat[i] = dat[i * 2 + 1] + dat[i * 2 + 2];
}
}
}
void update(int l_, int r_, int c, int i, int l, int r)
{
if (l_ >= r || r_ <= l)
{
return;
}
else if (l >= l_ && r <= r_)
{
datChild[i] = datChild[i] + ((ll)c);
ll res = ((ll)c) * ((ll)(r - l));
dat[i] = dat[i] + res;
while (i > 0)
{
i = (i - 1) / 2;
dat[i] = dat[i] + res;
}
}
else
{
update(l_, r_, c, i * 2 + 1, l, (l + r) / 2);
update(l_, r_, c, i * 2 + 2, (l + r) / 2, r);
}
}
ll query(int l_, int r_, int i, int l, int r)
{
if (l_ >= r || r_ <= l)
{
return 0LL;
}
else if (l >= l_ && r <= r_)
{
return dat[i];
}
else
{
ll res = 0LL;
res = res + query(l_, r_, i * 2 + 1, l, (l + r) / 2);
res = res + query(l_, r_, i * 2 + 2, (l + r) / 2, r);
res = res + ((ll)(min(r, r_) - max(l, l_))) * datChild[i];
return res;
}
}
int main()
{
int q;
while (~scanf("%d%d", &n_, &q))
{
input();
init();
char c;
int l, r, v;
while (q--)
{
scanf("\n%c", &c);
if (c == 'Q')
{
scanf("%d%d", &l, &r);
printf("%lld\n", query(l - 1, r, 0, 0, n));
}
else if (c == 'C')
{
scanf("%d%d%d", &l, &r, &v);
update(l - 1, r, v, 0, 0, n);
}
}
}
return 0;
}
#include
using namespace std;
// 131072
typedef long long ll;
int num[100007], n, n_;
ll bitChild[131080], bit[131080];
void input()
{
for (int i = 1; i <= n_; i++)
{
scanf("%d", &num[i]);
}
}
void init()
{
n = 1;
while (n < n_)
{
n = n * 2;
}
for (int i = 0; i <= n; i++)
{
bit[i] = 0LL;
bitChild[i] = 0LL;
}
}
//[1,r]每个值都加上v
void update(int r, int v)
{
for (int i = r; i > 0; i = i - (i & (-i)))
{
bitChild[i] = bitChild[i] + ((ll)v);
ll result = ((ll)(i & (-i))) * ((ll)v);
for (int j = i; j <= n; j = j + (j & (-j)))
{
bit[j] = bit[j] + result;
}
}
}
//[1,r]的sum
ll query(int r)
{
ll result = 0LL;
for (int i = r; i > 0; i = i - (i & (-i)))
{
result = result + bit[i];
for (int j = i + (i & (-i)); j <= n; j = j + (j & (-j)))
{
result = result + bitChild[j] * (i & (-i));
}
}
return result;
}
void push()
{
for (int i = 1; i <= n_; i++)
{
update(i, num[i]);
update(i - 1, num[i] * (-1));
}
}
int main()
{
int l, r, v, q;
char c;
while (~scanf("%d%d", &n_, &q))
{
input();
init();
push();
while (q--)
{
scanf("\n%c", &c);
if (c == 'Q')
{
scanf("%d%d", &l, &r);
ll all = query(r);
ll left = query(l - 1);
printf("%lld\n", all - left);
}
else if (c == 'C')
{
scanf("%d%d%d", &l, &r, &v);
update(r, v);
update(l - 1, v * (-1));
}
}
}
return 0;
}