我的算法不可能这么简单—ST表

文章目录

  • ST表
  • 题目
    • ST表的引出
    • ST表实现
      • 预处理
      • 查询
  • 例题总代码
  • 额外经验

ST表

  • ST表 (sparse table)即稀疏表

它可以在 O ( n l o g 2 n ) O(nlog_2 n) O(nlog2n)内预处理, O ( 1 ) O(1) O(1)内查询:

  •    1. 区间最大值
       2. 区间最小值
       3. 区间最大公约数
       4. 区间最小公倍数
    

满足吸收率的操作貌似都可以? (吸收率,详见离散数学代数系统的吸收率)


  • ST表为离线算法,因此区间给定后,不能进行修改,否则整张表将要重新计算,时间复杂度将会变得非常高。对于带有修改的操作可以使用线段树或树状数组。那我还学什么ST表?树状数组它不香吗?
  • ST表依旧被广为使用的原因是其优秀的时间复杂度,以及码量少。其实树状数组的码量好像更少,但是权当学习倍增思想了。

题目

洛谷-> P3865 【模板】ST表

我的算法不可能这么简单—ST表_第1张图片
当然这道题用线段树、树状数组。

ST表的引出

我们以上面这道题为例,详细的引出ST表。

  • 最好想的朴素算法
    它查询哪个区间,我们直接循环遍历那个区间,求出最大值。
#include 
using namespace std;
#define int long long

//快读
inline int read(){
     
    int x=0,f=1;char ch=getchar();
    while (!isdigit(ch)){
     if (ch=='-') f=-1;ch=getchar();}
    while (isdigit(ch)){
     x=x*10+ch-48;ch=getchar();}
    return x*f;
}
#define read read() //我觉得写俩括号太难了

const int maxn = 1e5+9;
int a[maxn],n,m;

signed main(){
     
    n=read,m=read;
    for(int i=1;i<=n;++i)
        a[i] = read;
    while(m--){
     
        int l=read,r=read;
        int mn = LONG_LONG_MIN;
        for(int i=l;i<=r;++i)
            mn = max(mn,a[i]);
        printf("%lld\n",mn);
    }

    return 0;
}

很显然时间复杂度是 O ( n m ) O(nm) O(nm) 的,对这题来说,必超时。

我们考虑一下能不能优化一下。考虑到它的区间是不变的,我们可以进行打表,预处理出所有区间的最大值,查询的时候可以直接查出来。

  • 打表优化(动态规划)
    预处理出所有区间的最大值,查询的时候直接输出。

首先我们定义一个 i 行 j 列的数组 int ans[i][j] ,用来表示区间 [ i , j ] [i,j] [i,j] 的最大值。
我们需要找一下打表的方法(状态转移方程)

  • 很显然当 i = = j i==j i==jans[i][j] = a[j]

我的算法不可能这么简单—ST表_第2张图片

  • 由上图不难看出当 i ≠ j i ≠ j i=j 时,状态转移方程为 ans[i][j] = max(ans[i][j-1] , a[j]);

因此我们可以写出打表代码:

#include 
using namespace std;
#define int long long

//快读
inline int read(){
     
    int x=0,f=1;char ch=getchar();
    while (!isdigit(ch)){
     if (ch=='-') f=-1;ch=getchar();}
    while (isdigit(ch)){
     x=x*10+ch-48;ch=getchar();}
    return x*f;
}
#define read read() //我觉得写俩括号太难了

const int maxn = 1e5+9;
int a[maxn],n,m;
int ans[maxn][maxn];

signed main(){
     
    n=read,m=read;
    for(int i=1;i<=n;++i)
        a[i] = read;
    //预处理
    for(int i=1;i<=n;++i) 
        for(int j=i;j<=n;j++){
     
            if(i==j)  ans[i][j] = a[j];
            else      ans[i][j] = max(ans[i][j-1],a[j]);
        }

    while(m--){
     
        int l=read,r=read;
        printf("%lld\n",ans[l][r]);
    }

    return 0;
}

对代码进行分析分析(事实上这个代码在我电脑上根本运行不了,内存太大了 ),首先是空间复杂度,很显然的 O ( n 2 ) O(n^2) O(n2),其次是时间复杂度,刚开始进行了一次预处理,预处理是 O ( n 2 ) O(n^2) O(n2) 的,然后进行了 m 次询问,所以总时间复杂度为 O ( n 2 + m ) O(n^2+m) O(n2+m) ,这个复杂度比朴素算法要好一些 (因为m比n大的多),但是无论是空间还是时间都还是不能满足该题目。

我们需要继续优化,在上面的预处理中,我们可以发现,每次更新都只是将区间扩大了1个,这必然会导致非常多的重复值,我们能不能一次将区间扩大很多个,同时又能保证每个区间都能被覆盖到。于是乎考虑—倍增!千呼万唤始出来

ST表实现

首先定义一个数组 int ans[i][j] 其表示的意义不再是区间 [ i , j ] [i,j] [i,j]的最大值,而是借助了倍增思想,每次扩充 2 j 2^j 2j 个数,故表示的是从 i 开始长度为 2 j 2^j 2j 的区间 , 即 [ i , i + 2 j − 1 ] [i,i+2^j-1] [i,i+2j1] 这个区间的最大值。

预处理

现在我们来看一下倍增思想的预处理是什么样的。

  • 很显然 ans[i][0] 表示的是区间 [ i , i + 2 0 − 1 ] [i,i+2^0-1] [i,i+201] [ i , i ] [i,i] [i,i] , 所以 ans[i][0] = a[i]
  • 由于 a [ i ] a[i] a[i] 已经被用过了,所以我们可以知道 a n s [ i ] [ j ] ans[i][j] ans[i][j] 的转移方程不会再与 a [ i ] 或 a [ j ] a[i]或a[j] a[i]a[j] 产生关系

我的算法不可能这么简单—ST表_第3张图片

  • 我们需要找到两个已经处理完毕的区间,并且这两个小区间能够覆盖住新的更大的区间,很明显我们能够想到 ans[i][j-1] ,通过上面的图我们可以找到另一块更小的区间 a n s [ i + 2 j − 1 ] [ j − 1 ] ans[i+2^{j-1}][j-1] ans[i+2j1][j1], 因此可以得到状态转移方程 : ans[i][j] = max(ans[i][j-1],ans[i+(1<<(j-1))][j-1])
  • 并且我们需要保证每次转移状态时 a n s [ i + 2 j − 1 ] [ j − 1 ] ans[i+2^{j-1}][j-1] ans[i+2j1][j1] 已经被更新过,因此我们的外层循环应该是 j ,而内层循环应该是 i,因为 i 更新的速度快。

于是我们的预处理代码为:

void proc(){
     
    for(int i=1;i<=n;++i)
        ans[i][0] = read; //ans[i][0] = a[i] 所以我们没有必要再开一个a数组,直接输入即可
    for(int j=1;j<=log2(n);++j)
        for(int i=1;i+(1<<j)-1<=n;++i)//i+(1<
            ans[i][j] = max(ans[i][j-1],ans[i+(1<<(j-1))][j-1]);
}

此时预处理的时间复杂度为 O ( n l o g 2 n ) O(nlog_2n) O(nlog2n)

查询

在上面我们已经预处理出了所有 [ i , i + 2 j − 1 ] [i,i+2^j-1] [i,i+2j1]的区间。给定 l , r l,r l,r 我们怎么查询 [ l , r ] [l,r] [l,r]的最大值呢?

  • 不妨以 [ 1 , 14 ] [1,14] [1,14] 来说明一下,为了查询我们需要将区间拆分成长度为 2 k 2^k 2k的小区间,不难算出 [ 1 , 14 ] = [ 1 , 8 ] ∪ [ 9 , 12 ] ∪ [ 13 , 14 ] [1,14] = [1,8] ∪ [9,12] ∪ [13,14] [1,14]=[1,8][9,12][13,14]
  • 因此区间 [ l , r ] [l,r] [l,r] 的最大值为 max(ans[1][3],ans[9][2],ans[13][1])
  • 而更一般的该如何拆分呢,考虑二进制, [ 1 , 14 ] [1,14] [1,14]区间长度为14,14的二进制为 1110,也即 14 = 2 3 + 2 2 + 2 1 14=2^3+2^2+2^1 14=23+22+21 ,这个时候你应该明白了我们的小区间是如何拆分的,事实上任何一个数 n 都能拆成形如 2 a 1 + 2 a 2 + 2 a 3 + . . . + . . . 2 a n 2^{a_1}+2^{a_2}+2^{a_3}+...+...2^{a_n} 2a1+2a2+2a3+...+...2an 的形式的
  • 所以我们可以求出任何一个区间长度的形如上述的表示方法,然后求出所有小区间的最大值即为要求区间的最大值。

此时的单次查询时间复杂度为 O ( l o g 2 n ) O(log_2n) O(log2n),总的时间复杂度为 O ( n l o g 2 n + m l o g 2 n ) O(nlog_2n+mlog_2n) O(nlog2n+mlog2n),此时仍不能达到通过题目的程度。我们在开始已经说过ST表查询时间复杂度可以达到 O ( 1 ) O(1) O(1),我们需要继续优化。

事实上,我们真的有必要将区间划分为这么多的小区间来查询吗?

  • 对于一段区间,拆分成这样
    我的算法不可能这么简单—ST表_第4张图片
    和拆分成这样
    我的算法不可能这么简单—ST表_第5张图片
    似乎没有什么区别,因为最大值满足吸收率,所以就算出现交集也不会影响到最终结果。所以我们根本没必要将区间划分成这么多份,而仅仅需要划分成两份即可。
  • 假设区间 [ l , r ] [l,r] [l,r] 的长度为 s ,那么我们只要找到不大于 s 的最大的 2 k 2^k 2k ,然后将区间划分成为 [ l , l + 2 k − 1 ] , [ r − 2 k + 1 , r ] [l,l+2^k-1],[r-2^k+1,r] [l,l+2k1],[r2k+1,r] ,即可。
  • 为什么非要找到不大于 s 的最大的 2 k 2^k 2k 呢,因为我们要保证 r − 2 k + 1 ≤ l + 2 k − 1 r-2^k+1≤ l+2^k-1 r2k+1l+2k1 ,即这俩区间必须能够覆盖住整个区间 [ l , r ] [l,r] [l,r]
  • 而要找到不大于 s 的最大的 2 k 2^k 2k ,直接使用 int k = log2(r-l+1) 即可获得。
  • 这样无论什么区间,我们只需要返回 max(ans[l][k],ans[r-(1< 即可。

查询代码:

int query(int l,int r){
     
    int k = log2(r-l+1);
    return max(ans[l][k],ans[r-(1<<k)+1][k]);
}

至此我们已经将ST表写完!读到这里,你应该已经懂得为什么ST表只能处理满足吸收率的运算,因为ST表查询的区间是重叠起来的 ! 因此不能用来查询区间和等问题。

例题总代码

#include 
using namespace std;
#define int long long

//快读
inline int read(){
     
    int x=0,f=1;char ch=getchar();
    while (!isdigit(ch)){
     if (ch=='-') f=-1;ch=getchar();}
    while (isdigit(ch)){
     x=x*10+ch-48;ch=getchar();}
    return x*f;
}
#define read read() //我觉得写俩括号太难了

const int maxn = 1e5+9;
int a[maxn],n,m;
int ans[maxn][30];

void proc(){
     
    for(int i=1;i<=n;++i)
        ans[i][0] = read;
    for(int j=1;j<=log2(n);++j)
        for(int i=1;i+(1<<j)-1<=n;++i)
            ans[i][j] = max(ans[i][j-1],ans[i+(1<<(j-1))][j-1]);
}

int query(int l,int r){
     
    int k = log2(r-l+1);
    return max(ans[l][k],ans[r-(1<<k)+1][k]);
}

signed main(){
     
    n=read,m=read;

    proc();

    while(m--){
     
        int l=read,r=read;
        printf("%lld\n",query(l,r));
    }

    return 0;
}

  • 但是由于 log2() 的复杂度不明确,所以对于 m 次查询(而且m远远大于n),我们没有必要每次都计算一个 log2(r-l+1) ,我们可以预处理出 1-n 的所有 log2 值,然后直接使用即可。
  • 代码如下:
#include 
using namespace std;
#define int long long

//快读
inline int read(){
     
    int x=0,f=1;char ch=getchar();
    while (!isdigit(ch)){
     if (ch=='-') f=-1;ch=getchar();}
    while (isdigit(ch)){
     x=x*10+ch-48;ch=getchar();}
    return x*f;
}
#define read read() //我觉得写俩括号太难了

const int maxn = 1e5+9;
int a[maxn],n,m;
int ans[maxn][30];
int Log[maxn];

void proc(){
     
    for(int i=1;i<=n;++i)
        ans[i][0] = read;
    for(int j=1;j<=Log[n];++j)
        for(int i=1;i+(1<<j)-1<=n;++i)
            ans[i][j] = max(ans[i][j-1],ans[i+(1<<(j-1))][j-1]);
}

int query(int l,int r){
     
    return max(ans[l][Log[r-l+1]],ans[r-(1<<Log[r-l+1])+1][Log[r-l+1]]);
}

signed main(){
     
    n=read,m=read;
    for(int i=1;i<=n;++i)
        Log[i] = log2(i);//预处理出log2
    proc();

    while(m--){
     
        int l=read,r=read;
        printf("%lld\n",query(l,r));
    }

    return 0;
}

时间大约提升了200ms

  • 此外还有另一种预处理 log2() 的方式,利用递推思想,避免调用内置的 log2() 函数:
    递推式为:
    我的算法不可能这么简单—ST表_第6张图片
  • 代码如下:
#include 
using namespace std;
#define int long long

//快读
inline int read(){
     
    int x=0,f=1;char ch=getchar();
    while (!isdigit(ch)){
     if (ch=='-') f=-1;ch=getchar();}
    while (isdigit(ch)){
     x=x*10+ch-48;ch=getchar();}
    return x*f;
}
#define read read() //我觉得写俩括号太难了

const int maxn = 1e5+9;
int a[maxn],n,m;
int ans[maxn][30];
int Log[maxn];

void proc(){
     
    for(int i=1;i<=n;++i)
        ans[i][0] = read;
    for(int j=1;j<=Log[n];++j)
        for(int i=1;i+(1<<j)-1<=n;++i)
            ans[i][j] = max(ans[i][j-1],ans[i+(1<<(j-1))][j-1]);
}

int query(int l,int r){
     
    return max(ans[l][Log[r-l+1]],ans[r-(1<<Log[r-l+1])+1][Log[r-l+1]]);
}

signed main(){
     
    n=read,m=read;

    for(int i=2;i<=n;++i)
        Log[i] = Log[i/2]+1;//递推式
    proc();

    while(m--){
     
        int l=read,r=read;
        printf("%lld\n",query(l,r));
    }

    return 0;
}

时间相对于 log2 的预处理提升了大约40ms,貌似意义不太大。

额外经验

P2251 质量检测

  • 求区间最小值,只需把上面的max换成min即可
#include 
using namespace std;
#define int long long
#define putlen putchar('\n')
//快读
inline int read(){
     
    int X=0; bool flag=1; char ch=getchar();
    while(ch<'0'||ch>'9') {
     if(ch=='-') flag=0; ch=getchar();}
    while(ch>='0'&&ch<='9') {
     X=(X<<1)+(X<<3)+ch-'0'; ch=getchar();}
    if(flag) return X;
    return ~(X-1);
}
#define read read()
//快输
inline void print(int x){
     
    if(x<0){
     putchar('-');x=-x;}
    if(x>9) print(x/10);
    putchar(x%10+'0');
}

int st[1000006][40];
int n,m;

int query(int l,int r){
     
    int k = log2(r-l+1);
    return min(st[l][k],st[r-(1<<k)+1][k]);
}

signed main(){
     
    n=read,m=read;
    for(int i=1;i<=n;++i) st[i][0] = read;

    for(int j=1;j<=log2(n);++j)
        for(int i=1;i<=n-(1<<j)+1;++i)
            st[i][j] = min(st[i][j-1],st[i+(1<<(j-1))][j-1]);

    for(int i=1;i<=n-m+1;++i){
     
        print(query(i,m+i-1));
        putlen;
    }

    return 0;
}

P2880 [USACO07JAN]Balanced Lineup G

  • 求区间最大值与最小值之差,建两个表就行了,一个维护最大,一个维护最小
#include 
using namespace std;
#define int long long
#define putlen putchar('\n')
//快读
inline int read(){
     
    int X=0; bool flag=1; char ch=getchar();
    while(ch<'0'||ch>'9') {
     if(ch=='-') flag=0; ch=getchar();}
    while(ch>='0'&&ch<='9') {
     X=(X<<1)+(X<<3)+ch-'0'; ch=getchar();}
    if(flag) return X;
    return ~(X-1);
}
#define read read()
//快输
inline void print(int x){
     
    if(x<0){
     putchar('-');x=-x;}
    if(x>9) print(x/10);
    putchar(x%10+'0');
}

int stMAX[50004][40];
int stMIN[50004][40];
int n,m;

int query(int l,int r){
     
    int k = log2(r-l+1);
    return max(stMAX[l][k],stMAX[r-(1<<k)+1][k]) - min(stMIN[l][k],stMIN[r-(1<<k)+1][k]);
}

signed main(){
     
    n=read,m=read;
    for(int i=1;i<=n;++i) stMAX[i][0] = stMIN[i][0] = read;

    for(int j=1;j<=log2(n);++j)
        for(int i=1;i<=n-(1<<j)+1;++i){
     
            stMIN[i][j] = min(stMIN[i][j-1],stMIN[i+(1<<(j-1))][j-1]);
            stMAX[i][j] = max(stMAX[i][j-1],stMAX[i+(1<<(j-1))][j-1]);
        }

    for(int i=1;i<=m;++i){
     
        int l=read,r=read;
        print(query(l,r));
        putlen;
    }

    return 0;
}

你可能感兴趣的:(我的算法不可能这么简单,算法,acm竞赛)