树状数组(Fenwick tree,又名binary indexed tree),是一种很实用的数据结构。它通过用节点i,记录数组下标在[ i –2^k + 1, i]这段区间的所有数的信息(其中,k为i的二进制表示中末尾0的个数,设lowbit(i) = 2^k),实现在O(lg n) 时间内对数组数据的查找和更新。
树状数组的传统解释图,不能很直观的看出其所能进行的更新和查询操作。其最主要的操作函数lowbit(k)与数的二进制表示相关,本质上仍是一种二分。因而可以通过二叉树,对其进行分析。事实上,从二叉树图,我们对它所能进行的操作和不能进行的操作一目了然。
和前面提到的点树类似,先画一棵二叉树,然后对节点中序遍历(点树是采用广度优先),每个节点仍然只记录左子树信息,见图:
由于采用的是中序遍历,从节点1到节点k时,刚好有k个叶子被统计。
可以证明:
叶子k,一定在节点k的左子树下。
以节点k为根的树,其左子树共有叶子lowbit(k)
节点k的父节点是:k + lowbit(k) 或 k - lowbit(k)
节点k + lowbit(k) 是节点k的最近父节点,且节点k在它的左子树下。
节点k - lowbit(k) 是节点k的最近父节点,且节点k在它的右子树下。
节点k,统计的叶子范围为:(k - lowbit(k), k]。
节点k的左孩子是:k - lowbit(k) / 2
下面分析树状数组两面主要应用:
1 更新数据x,进行区间查询。
2 更新区间,查询某个数。
由于,树状数组只统计了左子树的信息,因而只能查询更新区间[1, x]。只在在满足[x,y]的信息可以由[1,x-1]和[1,y]的信息推导出时,才能进行区间[x,y]的查询更新。这也是树状数组不能用于任意区间求最值的根本原因。
先定义两个集合:
up_right(k) : 节点k所有的父节点,且节点k在它们的左子树下。
up_left(k) : 节点k所有的父节点,且节点k在它们的右子树下。
1 更新数据x,查询区间[1,y]。
显然,更新叶子x,要找出叶子x在哪些节点的左子树下。因而节点k、所有的up_right(k)
都要更新。
查询[1, y],实际上就是把该区间拆分成一系列小区间,并找出统计这些区间的节点。可以通过找出y在哪些节点的右子树下,这些节点恰好不重复的统计了区间[1, y-1]。因而要访问节点y、所有的up_left(y)。
2 更新区间[1,y],查询数据x
这和前面的操作恰好相反。与前面的最大不同之处在于:节点保存的不再是其叶子总个数这些信息,而是该区间的所有叶子都改变了多少。也就是说:每个叶子的信息,分散到了所有对它统计的节点上。因此操作和前面相似:
更新[1,y]时,更新节点y、所有up_left(y)。
查询x时, 访问x、所有up_right(x)。
前面的树状数组,只对左子树信息进行统计,如果从后往前读数据初始化树状数组,则变成只对右子树信息进行统计,这时更新和查询操作,刚好和前面的相反。
一般情况下,树状数组比点树省空间,对区间[1, M]只要M+1空间,查询更新时定位节点比较快,定位父节点和左右孩子相对麻烦点(不过,一般也不用到。从上往下查找,可参考下面代码中的erease_nth函数(删除第n小的数))。
下面是使用树状数组的实现代码(求逆序数和模拟约瑟夫环问题):
树状数组
//
www.cnblogs.com/flyinghearts
#include
<
cstdio
>
#include
<
cstring
>
#include
<
cassert
>
template
<
int
N
>
struct
Round2k
{
enum
{ down
=
Round2k
<
N
/
2u
>
::down
*
2
}; };
template
<>
struct
Round2k
<
1
>
{
enum
{ down
=
1
}; };
template
<
int
Total, typename T
=
int
>
//
区间[1, Total]
class
BIT {
enum
{ Min2k
=
Round2k
<
Total
>
::down};
T info[Total
+
1
];
T sz;
//
可以用info[0]储存总大小
public
:
BIT() { clear(); }
void
clear() { memset(
this
,
0
,
sizeof
(
*
this
));}
int
size() {
return
sz; }
int
lowbit(
int
idx) {
return
idx
&
-
idx;}
//
寻找最近的父节点,left_up/right_up 分别使得idx在其右/左子树下
void
left_up(
int
&
idx) { idx
-=
lowbit(idx); }
void
right_up(
int
&
idx) { idx
+=
lowbit(idx); }
void
update(
int
idx ,
const
int
val
=
1
) {
//
叶子idx 改变val个
assert(idx
>
0
);
sz
+=
val;
for
(; idx
<=
Total; right_up(idx)) info[idx]
+=
val;
}
void
init(
int
arr[],
int
n) {
//
arr[i]为叶子i+1的个数
assert(n
<=
Total);
sz
=
n;
//
for (int i = 0; i < n; ) {
//
info[i + 1] = arr[i];
//
if (++i >= n) break;
//
info[i + 1] = arr[i];
//
++i;
//
for (int j = 1; j < lowbit(i); j *= 2u) info[i] += info[i - j];
//
}
for
(
int
i
=
0
; i
<
n; ) {
info[i
+
1
]
=
arr[i];
if
(
++
i
>=
n)
break
;
int
sum
=
arr[i];
int
pr
=
++
i;
left_up(pr);
for
(
int
j
=
i
-
1
; j
>
pr; left_up(j)) sum
+=
info[j];
info[i]
=
sum;
}
}
int
count(
int
idx) {
//
[1,idx] - [1, idx-1]
assert(idx
>
0
);
int
sum
=
info[idx];
//
int pr = idx;
//
int pr = idx - lowbit(idx);
//
left_up(pr);
//
for (--idx; idx > pr; left_up(idx)) sum -= info[idx];
//
//
return sum;
for
(
int
j
=
1
; j
<
lowbit(idx); j
*=
2u
) sum
-=
info[idx
-
j];
return
sum;
}
int
lteq(
int
idx) {
//
小等于
assert(idx
>=
1
&&
idx
<=
Total);
int
sum
=
0
;
for
(; idx
>
0
; left_up(idx)) sum
+=
info[idx];
return
sum;
}
int
gt(
int
idx) {
return
sz
-
lteq(idx); }
//
大于
int
operator
[](
int
n) {
return
erase_nth(n,
0
); }
//
第n小
int
erase_nth(
int
n,
const
bool
erase_flag
=
true
)
//
删除第n小的数
{
assert(n
>=
1
&&
n
<=
sz);
sz
-=
erase_flag;
int
idx
=
Min2k;
//
从上往下搜索,先定位根节点
for
(
int
k
=
idx
/
2u
; k
>
0
; k
/=
2u
) {
int
t
=
info[idx];
if
(n
<=
info[idx]) { info[idx]
-=
erase_flag; idx
-=
k;}
//
进入左子树
else
{
n
-=
t;
if
(Total
!=
Min2k
&&
Total
!=
Min2k
-
1
)
//
若不是完全二叉树
while
(idx
+
k
>
Total) k
/=
2u
;
//
则必须计算右孩子的编号
idx
+=
k;
//
进入右子树
}
}
assert(idx
%
2u
);
//
最底层节点m一定是奇数,有两个叶子m,m+1
if
(n
>
info[idx])
return
idx
+
1
;
//
节点m+1前面已经更新过
info[idx]
-=
erase_flag;
return
idx;
}
void
show()
{
for
(
int
i
=
1
; i
<=
Total;
++
i)
if
(count(i)) printf(
"
%2d
"
, i);
printf(
"
\n
"
);
}
};
void
ring()
//
约瑟夫环
{
const
int
N
=
17
;
//
N个人编号:1,2, ... N
const
int
M
=
7
;
//
报数:1到M,报到M的出列
printf(
"
N: %d M: %d\n
"
, N, M);
BIT
<
N
>
pt;
//
for (int i = 0; i < N; ++i) pt.update(i + 1);
int
arr[N];
for
(
int
i
=
0
; i
<
N;
++
i) arr[i]
=
1
;
pt.init(arr, N);
for
(
int
j
=
N, k
=
0
; j
>=
1
;
--
j) {
k
=
(k
+
M
-
1
)
%
j;
int
t
=
pt.erase_nth(k
+
1
);
printf(
"
turn: %2d out: %2d rest:
"
, N
-
j, t);
pt.show();
}
printf(
"
\n\n
"
);
}
int
ra(
int
arr[],
int
len)
//
求逆序数-直接搜索
{
int
sum
=
0
;
for
(
int
i
=
0
; i
<
len
-
1
;
++
i)
for
(
int
j
=
i
+
1
; j
<
len;
++
j)
if
(arr[i]
>
arr[j])
++
sum;
return
sum;
}
template
<
int
N
>
int
rb(
int
arr[],
int
len)
//
求逆序数-使用树状数组
{
BIT
<
N
>
pt;
int
sum
=
0
;
for
(
int
i
=
0
; i
<
len;
++
i) {
pt.update(arr[i]
+
1
);
sum
+=
pt.gt(arr[i]
+
1
);
}
return
sum;
}
int
main()
{
int
arr[]
=
{
4
,
3
,
2
,
1
,
0
,
5
,
1
,
3
,
0
,
2
};
const
int
N
=
sizeof
(arr)
/
sizeof
(arr[
0
]);
printf(
"
%d %d\n\n
"
, ra(arr, N), rb
<
6
>
(arr, N));
ring();
}