RMQ-ST算法[转]

转自这个博客 文章中间我会利用/**/符号附一些自己的理解

概述

RMQ(Range Minimum/Maximum Query),即区间最值查询,是指这样一个问题:对于长度为n的数列A,回答若干询问RMQ(A,i,j)(i,j<=n),返回数列A中下标在i,j之间的最小/大值。这两个问题是在实际应用中经常遇到的问题,下面介绍一下解决这两种问题的比较高效的算法。当然,该问题也可以用线段树(也叫区间树)解决,算法复杂度为:O(N)~O(logN),这里我们暂不介绍。

RMQ算法

对于该问题,最容易想到的解决方案是遍历,复杂度是O(n)。但当数据量非常大且查询很频繁时,该算法无法在有效的时间内查询出正解。

本节介绍了一种比较高效的在线算法(ST算法)解决这个问题。所谓在线算法,是指用户每输入一个查询便马上处理一个查询。该算法一般用较长的时间做预处理,待信息充足以后便可以用较少的时间回答每个查询。ST(Sparse Table)算法是一个非常有名的在线处理RMQ问题的算法,它可以在O(nlogn)时间内进行预处理,然后在O(1)时间内回答每个查询。


[一] 首先是预处理,用动态规划(DP)解决。

设A[i]是要求区间最值的数列,F[i, j]表示从第i个数起连续2^j个数中的最大值。(DP的状态)

例如:

A数列为:3 2 4 5 6 8 1 2 9 7

F[1,0]表示第1个数起,长度为2^0=1的最大值,其实就是3这个数。同理 F[1,1] = max(3,2) = 3, F[1,2]=max(3,2,4,5) = 5,F[1,3] = max(3,2,4,5,6,8,1,2) = 8;

并且我们可以容易的看出F[i,0]就等于A[i]。(DP的初始值)

这样,DP的状态、初值都已经有了,剩下的就是状态转移方程。

我们把F[i,j]平均分成两段(因为f[i,j]一定是偶数个数字, /* 因为序列的长度是2^j */),从 i 到i + 2 ^ (j - 1) - 1为一段,i + 2 ^ (j - 1)到i + 2 ^ j - 1为一段(长度都为2 ^ (j - 1))。用上例说明,当i=1,j=3时就是3,2,4,5 和 6,8,1,2这两段。F[i,j]就是这两段各自最大值中的最大值。于是我们得到了状态转移方程F[i, j]=max(F[i,j-1], F[i + 2^(j-1),j-1])。

代码如下:

void RMQ(int num) {
  for (int j = 1; j < 20; ++j)
    for (int i = 1; i <= num; ++i)
      if (i + (1 << j) - 1 <= num) {
        maxsum[i][j] = max(maxsum[i][j - 1], maxsum[i + (1 << (j - 1))][j - 1]);
        minsum[i][j] = min(minsum[i][j - 1], minsum[i + (1 << (j - 1))][j - 1]);
      }
}

这里我们需要注意的是循环的顺序,我们发现外层是j,内层是i,这是为什么呢?可以是i在外,j在内吗?

答案是不可以。我们需要理解这个状态转移方程的意义。

状态转移方程的含义是:先更新所有长度为F[i,0]即1个元素,然后通过2个1个元素的最值,获得所有长度为F[i,1]即2个元素的最值,然后再通过2个2个元素的最值,获得所有长度为F[i,2]即4个元素的最值,以此类推更新所有长度的最值。

而如果是i在外,j在内的话,我们更新的顺序就是F[1,0],F[1,1],F[1,2],F[1,3],表示更新从1开始1个元素,2个元素,4个元素,8个元素(A[0],A[1],....A[7])的最值,这里F[1,3] = max(max(A[0],A[1],A[2],A[3]),max(A[4],A[5],A[6],A[7]))的值,但是我们根本没有计算max(A[0],A[1],A[2],A[3])和max(A[4],A[5],A[6],A[7]),所以这样的方法肯定是错误的。

为了避免这样的错误,一定要好好理解这个状态转移方程所代表的含义。


[二] 然后是查询

假如我们需要查询的区间为(i,j),那么我们需要找到覆盖这个闭区间(左边界取i,右边界取j)的最小幂(可以重复,比如查询5,6,7,8,9,我们可以查询5678和6789)。

因为这个区间的长度为j - i + 1,所以我们可以取k=log2( j - i + 1),则有:RMQ(A, i, j)=max{F[i , k], F[ j - (2 ^ k) + 1, k]}。
/*
* F[i, k]从数学角度已经覆盖了长度j-i+1,但由于index是整形
* 如下例子中,F[2,2]并不能覆盖[2, 8],仅仅是覆盖了[2, 5],
* 于是还需要再来一个F[ j - (2^k)+1, k ] ,虽然和前面的F可能有重叠,
* 但是能完全考虑整个区间(也就是所谓的覆盖的区间可以重复))
*/
举例说明,要求区间[2,8]的最大值,k = log2(8 - 2 + 1)= 2,即求
max(F[2, 2],F[8 - 2 ^ 2 + 1, 2]) = max(F[2, 2],F[5, 2]);

在这里我们也需要注意一个地方,就是<<运算符和+-运算符的优先级。

比如这个表达式:5 - 1 << 2是多少?
答案是:(5 - 1) * 2 * 2 = 16。所以我们要写成5 - (1 << 2)才是5-1 * 2 * 2 = 1。

例题 POJ 3368 Frequent values

意思即有一个不下降数列,求区间[i, j]内重复次数最多的数字的“重复次数”。
首先我们可以离散化处理,维护一个数组F[i],表示当前数字到i这个位置时,已经重复了的次数。
比如 1 1 1 1 3 10 10 10,对应F[] = {1, 2, 3, 4, 1, 1, 2, 3}

对于每个询问(l,r),分为两个部分,前半部分求与l之前相同的数的个数直到t,后半部分从t开始直接用RMQ求解最大值就行了。
最后结果为max(前半部分,后半部分)。
拿上面那个例子而言,求区间[2, 7]的解,从l开始,重复的数字“1”,重复到了下标5位置,后半部分从5到j,我们利用RMQ求得为2,而前半部分重复的次数则为(5 - 2) = 3, 取ans = max{3, 2} 得最终解。

这时候可能会有一个小疑问,我们离散化统计的F是一段一段的,所需要求解的区间可能截断这个区间,使得F统计的数据不一定能用上。的确如此,在上面的方法中,由于我们分了前半部分和后半部分,所以前半部分被截断的问题被解决了,那中间部分后面的末尾部分难道不需要考虑截断,分类讨论吗?

答案是不需要,因为F[i]维持的是当前数字到i这个位置时,已经重复了的次数。比如F[7] = 2, 即10重复的次数到目前重复的次数是2,在F的尾巴中,已经是一个数字一个数字的截断统计了。

代码来自这个博客

#include 
#include 
#include 
using namespace std;

const int maxN=1e5+5;
int N, M, K, T;
int g[maxN], S[maxN];
int d[maxN][20];

void RMQ_init(int *A) {
    for (int i = 0; i < N; ++i) d[i][0] = A[i];
    for (int j = 1; (1 << j) <= N; ++j)
        for (int i = 0; i + (1 << j) - 1 < N; ++i)
            d[i][j] = max(d[i][j - 1], d[i + (1 << (j - 1))][j - 1]);
}
int RMQ(int L, int R) {
    int k = 0;
    while (1 << (k + 1) <= R - L + 1) ++k;
    return max(d[L][k], d[R - (1 << k) + 1][k]);
}

int main () {
#ifndef ONLINE_JUDGE
    freopen("data.in", "r", stdin);
#endif
    while (~scanf("%d", &N) && N)  {
        scanf("%d", &M);
        memset(S, 0, sizeof S);
        for (int i = 0; i < N; ++i) {
            scanf("%d", &g[i]);
            S[i] = 1;
            if (i && g[i] == g[i - 1])
                S[i] = S[i - 1] + 1;
        }
        RMQ_init(S);
        int u, v;
        while (M--) {
            scanf("%d%d", &u, &v);
            --u, --v;
            int idx = u + 1;
            while (idx <= v && g[idx] == g[idx - 1])
                ++idx;
            int ans;
            if (idx > v) ans = v - u + 1;
            else ans = max(idx - u, RMQ(idx, v));
            printf("%d\n", ans);
        }
    }
    return 0;
}

最近又做到一题利用到了这个RMQ-ST算法

来自hihoCoder的题目链接
题目大意是给定N个整数A1..An, 有M个询问,问[L,R]区间内最长的等差连续自数列长度是多少?
其中1 <= N, M <= 1000000 <= Ai <= 10000000

分析:首先,我们先用N时间从头到尾拉一遍数列,得到数组f[], 其中f[i]的意思是:到i为止的最长等差连续子数列长度. 打个比方:
数列 1 2 3 5 7 9
应该对应f[]的 1 2 3 2 3 4
怎么计算这个值? 首先数字本身是长度为1的等差数列,且任意两个相邻的数构成长度为2的数列:可以写出:

f[1] = 1, f[2] = 2;
for (int i = 3; i <= N; ++i) {
    if (num[i] - num[i - 1] == num[i - 1] - num[i - 2]) {
        f[i] = f[i - 1] + 1;
    else
        f[i] = 2;
}

然后,对于每次查询 [L, R] 我们和上一题一样分成[L, t] 和[t + 1, R], t作为分水岭,目的是修正左侧的被截断项, 我们独立计算从L开始,等差数列的元素个数是多少, 一个个去计算得为cnt_t,
那么左边我们便知,等差数列有cnt_t个元素,对于[t + 1, R], 由于这个区间中新等差数列的源头下标必然>=L, 所以f[]的值无需修正,直接RMQ_ST算法得到f[]中,[t+1, R]的最大值,然后和cnt_t比较,取二者之大即为解.

#include 
using namespace std;

const int maxN = 1000005, inf = 0x3f3f3f3f;
int num[maxN], f[maxN], arr[maxN][20];
int N, M;

void ST() {
    int i, j, k;
    for (int i = 1; i <= N; ++i)
        arr[i][0] = f[i];
    k = log((double)(N + 1)) / log(2.0);
    for (j = 1; j <= k; ++j)
        for (i = 1; i + (1 << j) - 1 <= N; ++i)
            arr[i][j] = max(arr[i][j - 1], arr[i + (1 << (j - 1))][j - 1]);
}

int rmq_max(int l, int r) {
    if (l > r)
        return 0;
    int k = log((double)(r - l + 1)) / log(2.0);
    return max(arr[l][k], arr[r - (1 << k) + 1][k]);
}

int main() {
    // freopen("data.in", "r", stdin);
    scanf("%d %d", &N, &M);
    for (int i = 1; i <= N; ++i) {
        scanf("%d", &num[i]);
    }
    f[1] = 1, f[2] = 2;
    for (int i = 3; i <= N; ++i) {
        if (num[i] - num[i - 1] == num[i - 1] - num[i - 2])
            f[i] = f[i - 1] + 1;
        else
            f[i] = 2;
    }
    ST();
    // for (int i = 1; i <= N; ++i) printf("%d ", f[i]);
    // 1 2 3 5 7 9
    // 1 2 3 2 3 4
    int L, R;
    for (int i = 1; i <= M; ++i) {
        scanf("%d %d", &L, &R);
        // 一样要分成2部分[L, t] 和[t+1, R]
        if (L + 1 > R) {
            printf("1\n");
            continue;
        }
        int cnt_t = 1, diff = 0;
        diff = num[L + 1] - num[L];
        int t;
        for (t = L + 1; t <= R; ++t) {
            if (num[t] - num[t - 1] == diff)
                ++cnt_t;
            else
                break;
        }
        int ans = max(cnt_t, rmq_max(t, R));
        printf("%d\n", ans);
    }
    return 0;
}

你可能感兴趣的:(RMQ-ST算法[转])