【模版】莫队(带修改莫队、树上莫队)算法


莫队算法


莫队算法是 离线 解决 区间查询 问题的一种高效暴力算法。
前置知识点: ① ① 分块 ② ② s o r t sort sort关键字排序 ③ ③ l c a lca lca查询(树上莫队)
时间复杂度: O ( n ∗ n ) O(n*\sqrt{n}) O(nn )
空间复杂度: O ( n + n ) O(n+\sqrt{n}) O(n+n )


算法流程:
一、确定分块大小:每块大小为 n \sqrt{n} n 。确定每个元素属于哪一块。

int part = sqrt(n);
int num = ceil((double)n/part);
for(int i = 1; i <= num; i++)
    for(int j = (i-1)*part+1; j <= i*part; j++)
        bel[j] = i;

二、用结构体存储询问并用 s o r t sort sort 排序:
对于每次询问的 左端点 l l l,按照 所在块 从小到大排序。
若左区间 l l l 所在块 相同,则按照 右端点 从小到大排序。

bool cmp(pro a, pro b)  {return bel[a.l] == bel[b.l] ? a.r < b.r : bel[a.l] < bel[b.l];}

(这样保证复杂度最优,左端点按块处理,右端点按从小到大的顺序处理)
三、令左端点为 1 1 1 ,右端点为 0 0 0 。挪动左右端点符合第一次询问。
每移动一个位置的端点,都需要根据题意来修改答案。当左右端点分别对应
当前询问的左右区间时,记录当前询问的答案。

额外说明:
l l l r r r 为当前的查询区间的左右端点, L L L R R R为需要查询区间的左右端点。

1 1 1.当左端点向右移动时,对答案的贡献是负的。
①在 l l l 向右移动 逼近 L L L 时,每移动一位,更新一次答案,注意是:先更新当前位置答案,然后向右移动。
②最终我们的更新答案的区间是 [ l , L − 1 ] [l,L-1] [l,L1]
L L L 这个位置我们并没有更新答案,在 l l l 向右移动时,我们只更新无关查询区间的答案。

2 2 2.当右端点向右移动时,对答案的贡献是正的。
①在 r r r 向右移动 逼近 R R R 时,每移动一位,更新一次答案,注意是:先向右移动,然后再更新新位置答案。
②最终我们更新答案的区间是 [ r , R ] [r,R] [r,R]
R R R 这个位置我们更新了答案,与上述 左端点向右移动时 相照应,左端点减,右端点加。多减的会被多加 进行补正, 最终两者交集区间即为 区间 [ L , R ] [L,R] [L,R] 的答案。

3 3 3.当左端点向左移动时,对答案的贡献是正的。
不再赘述,先移动,再更新答案,补上了 [ L , l − 1 ] [L,l-1] [L,l1] 部分的答案。

4 4 4.当右端点向左移动时,对答案的贡献是负的。
不再赘述,先更新答案,再移动,删去了 [ R + 1 , r ] [R+1,r] [R+1,r] 部分的答案。

先更新还是先移动,注意顺序。
可以用画图的方法 判断先后顺序。


莫队例题:[国家集训队]小Z的袜子
例题代码:

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;
int n,m;
int col[200200],bel[200200],le[200200],ri[200200],cnt[200200],ans[202020];
struct pro
{
    int l,r,in;
}q[200200];
int read()
{
    int rt = 0, in = 1; char ch = getchar();
    while(ch < '0' || ch > '9') {if(ch == '-') in = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9') {rt = rt * 10 + ch - '0'; ch = getchar();}
    return rt * in;
}
bool cmp(pro a, pro b)  {return bel[a.l] == bel[b.l] ? a.r < b.r : bel[a.l] < bel[b.l];}
int gcd(int a, int b)   {return b == 0 ? a : gcd(b, a%b);}
int calc(int x) {return x * (x-1) / 2;}
int main()
{
    n = read(), m = read();
    for(int i = 1; i <= n; i++) col[i] = read(); 
    for(int i = 1; i <= m; i++) 
    {
        le[i] = read(), ri[i] = read();
        q[i].l = le[i], q[i].r = ri[i], q[i].in = i;
    }
    int part = sqrt(n);
    int num = ceil((double)n/part);
    for(int i = 1; i <= num; i++)
        for(int j = (i-1)*part+1; j <= i*part; j++)
            bel[j] = i;
    sort(q+1, q+1+m, cmp);
    int l = 1, r = 0, Ans = 0;
    for(int i = 1; i <= m; i++)
    {
        int L = q[i].l, R = q[i].r;
        while(l < L)
        {
            Ans += calc(cnt[col[l]]-1) - calc(cnt[col[l]]);
            cnt[col[l++]]--;
        }
        while(l > L)
        {
            cnt[col[--l]]++;
            Ans += calc(cnt[col[l]]) - calc(cnt[col[l]]-1);
        }
        while(r < R)
        {
            cnt[col[++r]]++;
            Ans += calc(cnt[col[r]]) - calc(cnt[col[r]]-1);
        }
        while(r > R)
        {
            Ans += calc(cnt[col[r]]-1) - calc(cnt[col[r]]);
            cnt[col[r--]]--;
        }
        ans[q[i].in] = Ans;
    }
    for(int i = 1; i <= m; i++)
    {
        int a = ans[i], b = calc(ri[i] - le[i] + 1);
        if(a == 0)
        {
            printf("0/1\n");
            continue;
        }
        int k = gcd(a, b);
        a /= k, b /= k;
        printf("%d/%d\n",a,b);
    }
    system("pause");
    return 0;
}

莫队(带修改)


在莫队的基础上,增添了修改操作。
[国家集训队]数颜色/维护队列
修改和查询交错进行。
难点在于,查询操作是排序后的,我们无法保证某次查询是在哪些次修改后。
原做法是 只要固定好左右区间,就可以记录答案。


带修改莫队延伸出:
一、把修改操作也当成端点移动,当左右端点对应查询的左右端点时,再调整时间顺序,使之时间上也符合当次查询时的区间。因此需要调整 s o r t sort sort 排序。分块大小为 n 2 / 3 n^{2/3} n2/3。证明不会。
①按 左端点所在块 从小到大排序。
②左端点所在块 相同时,按 右端点所在块(也可以是右端点,两者几乎无区别) 从小到大排序。
③右端点所在块 也相同时,按 查询的时间顺序 从小到大排序

if(bel[a.l] == bel[b.l])    return bel[a.r] == bel[b.r] ? a.time < b.time : a.r < b.r;
else    return bel[a.l] < bel[b.l];

二、对于每次的修改操作,当我们进行修改时,我们不能进行简单的覆盖,而是要交换 修改前后的值(我们需要把修改前后的信息都存储下来)。因为我们以后复原的时候 还需要调整回来。

三、记录当前时间的变量 t i m e time time 向右(新时间)趋近时,先自增,然后更新答案。记录当前时间的变量 t i m e time time 向左(老时间)趋近时,先更新答案,再自减。因为 初始值为 0 0 0 ,视为当前所在时间已经记录。

例题代码:

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;
int n, m, cntq, cntm;
int col[2002000], bel[2002000], cnt[2002000], ans[2002000];
struct query
{
    int l, r, time, id;
}q[2002000];
struct modify
{
    int pos, val, last;
}p[2002000];
int read()
{
    int rt = 0, in = 1; char ch = getchar();
    while(ch < '0' || ch > '9') {if(ch == '-') in = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9') {rt = rt * 10 + ch - '0'; ch = getchar();}
    return rt * in;
}
bool cmp(query a, query b)
{
    if(bel[a.l] == bel[b.l])    return bel[a.r] == bel[b.r] ? a.time < b.time : a.r < b.r;
    else    return bel[a.l] < bel[b.l];
}
int main()
{
    n = read(), m = read();
    for(int i = 1; i <= n; i++) col[i] = read();
    for(int i = 1; i <= m; i++)
    {
        char ch;
        cin >> ch;
        if(ch == 'Q')
        {
            ++cntq;
            q[cntq].l = read(), q[cntq].r = read(), q[cntq].id = cntq, q[cntq].time = cntm;
        }
        if(ch == 'R')
        {
            ++cntm;
            p[cntm].pos = read(), p[cntm].val = read();
        }
    }
    int part = pow(n, 2.0 / 3.0);
    int num = ceil((double)n / part);
    for(int i = 1; i <= num; i++)
        for(int j = (i-1)*part+1; j <= i*part; j++)
            bel[j] = i;
    sort(q+1, q+1+cntq, cmp);
    int l = 1, r = 0, Ans = 0, time = 0;
    for(int i = 1; i <= cntq; i++)
    {
        int L = q[i].l, R = q[i].r, T = q[i].time;
        while(l < L)
        {
            cnt[col[l]]--;
            if(cnt[col[l]] == 0)    Ans--;
            l++;
        }
        while(l > L)
        {
            l--;
            cnt[col[l]]++;
            if(cnt[col[l]] == 1)    Ans++;
        }
        while(r < R)
        {
            r++;
            cnt[col[r]]++;
            if(cnt[col[r]] == 1)    Ans++;
        }
        while(r > R)
        {
            cnt[col[r]]--;
            if(cnt[col[r]] == 0)    Ans--;
            r--;
        }
        while(time < T)
        {
            time++;
            if(L <= p[time].pos && p[time].pos <= R)
            {
                cnt[col[p[time].pos]]--;
                if(cnt[col[p[time].pos]] == 0)  Ans--;
                cnt[p[time].val]++;
                if(cnt[p[time].val] == 1)  Ans++;
            }
            swap(col[p[time].pos], p[time].val);
        }
        while(time > T)
        {
            if(L <= p[time].pos && p[time].pos <= R)
            {
                cnt[col[p[time].pos]]--;
                if(cnt[col[p[time].pos]] == 0)  Ans--;
                cnt[p[time].val]++;
                if(cnt[p[time].val] == 1)  Ans++;
            }
            swap(col[p[time].pos], p[time].val);
            time--;
        }
        ans[q[i].id] = Ans;
    }
    for(int i = 1; i <= cntq; i++)  printf("%d\n",ans[i]);
    system("pause");
    return 0;0000000
}

树上莫队


将节点记录成一个序列,树上问题就转化成了序列问题。
于是就转成了普通莫队的情况。


如何找到一个合适的序列将树转化成序列呢?
普通的 d f s dfs dfs 序 解决不了问题,这时需要引入 欧拉序 的概念。
欧拉序是一种特殊的 d f s dfs dfs 序,在遍历某点 的前和后,均记录 当前的节点。
这样该序列长度为 2 ∗ n 2*n 2n 。 每个点会被记录两次。
f i r [ ] fir[] fir[] l s t [ ] lst[] lst[] 两个数组分别记录每个节点在序列中第一次出现和最后一次出现的位置。
v i s [ ] vis[] vis[] 表示该点之前是否被访问过


算法过程:

设树上两点 u u u v v v l c a lca lca w w w,且 f i r [ u ] < = f i r [ v ] fir[u] <= fir[v] fir[u]<=fir[v]

①若 w w w 是点 u u u (两点处在一条链上)。
那么 路径 u − > v u->v u>v 所经历的点 在 下标为 [ f i r [ u ] , f i r [ v ] ] [fir[u],fir[v]] [fir[u],fir[v]] 的子串中。
更新答案:① 出现两次的点贡献为 0 0 0 ,因为这类点 并不是 路径 u − > v u->v u>v 所经历的点。②路径 u − > v u->v u>v 上的点在子串中只出现了一次,贡献分别计算。

②若 w w w 不是点 u u u (两点不处在同一条链上)。
则这时候取的子串下标对应的是 [ l s t [ u ] , f i r [ v ] ] [lst[u],fir[v]] [lst[u],fir[v]] 。并存储这两点的 l c a lca lca,因为两点路径上只有 l c a lca lca 不处于这段子串内。 左区间端点不必取 f i r [ u ] fir[u] fir[u] , 因为 [ f i r [ u ] , l s t [ u ] − 1 ] [fir[u], lst[u]-1] [fir[u],lst[u]1] 这段区间是多余的,这段区间的点(除了u)都出现了两次,对答案没有贡献,可以删掉。对于第②类的询问, l c a lca lca 是需要额外添加的,更新答案后要将 l c a lca lca v i s vis vis 标记还原(因为 l c a lca lca 并不处于一个连续子串内)。


例题:Count on a tree II
例题代码:

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
using namespace std;
int n,m,depth,Ans;
int fa[404040][20],fir[404040],lst[404040],ans[404040],bel[404040],cnt[404040];
int head[404040],bas[404040],val[404040],tem[404040],deep[404040];
bool vis[400400];
struct list
{
    int to, nxt;
}e[402020];
struct query
{
    int l, r, lca, id;
}q[404040];
int read()
{
    int rt = 0, in = 1; char ch = getchar();
    while(ch < '0' || ch > '9') {if(ch == '-') in = -1; ch = getchar();}
    while(ch >= '0' && ch <= '9') {rt = rt * 10 + ch - '0'; ch = getchar();}
    return rt * in;
}
bool cmp(query a, query b)  {  return bel[a.l] == bel[b.l] ? a.r < b.r : bel[a.l] < bel[b.l];  }
void add_edge(int u, int v)
{
    e[++head[0]].to = v;
    e[head[0]].nxt = head[u];
    head[u] = head[0];
}
void dfs(int x)
{
    bas[++bas[0]] = x;
    fir[x] = bas[0];
    for(int i = head[x]; i; i = e[i].nxt)
    {
        if(deep[e[i].to]) continue;
        deep[e[i].to] = deep[x] + 1;
        fa[e[i].to][0] = x;
        dfs(e[i].to);
    }
    bas[++bas[0]] = x;
    lst[x] = bas[0];
}
int query_lca(int u, int v)
{
    if(deep[u] > deep[v])   swap(u, v);
    int d = deep[v] - deep[u];
    for(int i = depth; i >= 0; i--)
        if( (1 << i) & d)
            v = fa[v][i];
    if(u == v)  return v;
    for(int i = depth; i >= 0; i--)
        if(fa[u][i] != fa[v][i])
            u = fa[u][i], v = fa[v][i];
    return fa[u][0];
}
void work(int pos)
{
    if(!vis[pos])
    {
        vis[pos] = 1;
        cnt[val[pos]]++;
        if(cnt[val[pos]] == 1)  Ans++;
    }
    else if(vis[pos])
    {
        vis[pos] = 0;
        cnt[val[pos]]--;
        if(cnt[val[pos]] == 0)  Ans--;
    }
}
int main()
{
    n = read(), m = read();
    depth = log(n) / log(2) + 1;
    
    for(int i = 1; i <= n; i++) val[i] = tem[i] = read();
    sort(tem+1, tem+1+n);
    int len = unique(tem+1, tem+1+n) - tem - 1;
    for(int i = 1; i <= n; i++) val[i] = lower_bound(tem+1, tem+1+len, val[i]) - tem;
    for(int i = 1; i < n; i++)
    {
        int u = read(), v = read();
        add_edge(u, v); add_edge(v, u);
    }
    deep[1] = 1;
    dfs(1);
    for(int i = 1; i <= depth; i++)
        for(int j = 1; j <= n; j++)
            fa[j][i] = fa[fa[j][i-1]][i-1];
    int part = sqrt(bas[0]);
    int num = ceil((double)bas[0] / part);
    for(int i = 1; i <= num; i++)
        for(int j = (i-1)*part+1; j <= i*part; j++)
            bel[j] = i;
    for(int i = 1; i <= m; i++)
    {
        int l = read(), r = read();
        int lca = query_lca(l, r);
        if(fir[l] > fir[r]) swap(l, r);
        q[i].id = i;
        if(l == lca)    q[i].l = fir[l], q[i].r = fir[r];
        else    q[i].l = lst[l], q[i].r = fir[r], q[i].lca = lca;
    }
    sort(q+1, q+1+m, cmp);
    int l = 1, r = 0;
    for(int i = 1; i <= m; i++)
    {
        int L = q[i].l, R = q[i].r, lca = q[i].lca;
        while(l < L)    work(bas[l++]);
		while(l > L)    work(bas[--l]);
		while(r < R)    work(bas[++r]);
		while(r > R)    work(bas[r--]);
        if(lca) work(lca);
        ans[q[i].id] = Ans;
        if(lca) work(lca);
    }
    for(int i = 1; i <= m; i++) printf("%d\n",ans[i]);
    system("pause");
    return 0;
}

你可能感兴趣的:(算法【模版】)