最近公共祖(LCA)模板_祖先深度_区域祖先_(欧拉序列+标准RMQ+四毛子)O(n)-O(1)

前几天看到LCA,发现有种O(n)-O(1)的做法。本想找板子学习一下,苦寻无果。那就自己写写板子吧!因为实在才疏学浅,各位大佬要是遇到bug,或是wa的题。请告诉我!我去改!

//duobly On - O1
//顶点1-n
#include
using namespace std;
typedef long long ll;

const int maxn = 3e5 + 7; // 32768k = 4m 内存只够1e6
const int maxlen = 7e4 + 7; // n / BASE
struct edge {
    int to, next, dist;
}e[2 * maxn];

struct LCA
{
    int n; // 顶点数
    int tol; // 边的数量
    int cnt; // dfs搜索数量
    int root; //树根
    int head[maxn]; // i的第一条边
    int fa[maxn]; // 父节点
    int efa[maxn]; //指向父节点的边
    int dis[maxn]; //结点到根的距离
    int dep[2 * maxn];
    int ver[2 * maxn];// 欧拉序列 前深度后顶点
    int ind[maxn]; // 顶点i在数组第一次出现的位置

    int BASE = 15; //分块长度
    int STlen; // st表分块后长度



    int sta[1 << 14][15][15]; // 状态查找
    int STsta[maxlen]; // 每块的状态
    int msn[23][maxlen]; // st表
    int lg[maxlen]; // logi 向下取整

    void pre()//预处理
    {
        lg[1] = 0;
        for(int i = 2; i < 60007; i ++)
        {
            lg[i] = lg[i - 1];
            if(!(i & (i - 1)))
                lg[i] ++;
        }
        for(int i = 0; i < (1 << (BASE - 1)); i ++)
        {
            for(int l = 0; l < BASE; l ++)
            {
                sta[i][l][l] = l;
                int now = 0, minv = 0;
                for(int r = l + 1; r < BASE; r ++)
                {
                    sta[i][l][r] = sta[i][l][r - 1];
                    if((1 << (r - 1)) & i)
                        now ++;
                    else
                    {
                        now --;
                        if (now < minv)
                        {
                            minv = now;
                            sta[i][l][r] = r;
                        }
                    }
                }
            }
        }

    }

    void init(int n, int root) //重置
    {
        this->n = n;
        this->root = root;
        tol = 0;
        cnt = 0;
        for(int i = 0; i <= n; i ++)
        {
            head[i] = fa[i] = ind[i] = -1;
        }
        fa[root] = -2;
        efa[root] = -1;
        dis[root] = 0;
    }

    void addedge (int u, int v, int d) //加边
    {
        e[tol].to = v; e[tol].dist = d;
        e[tol].next = head[u]; head[u] = tol++;

        e[tol].to = u; e[tol].dist = d;
        e[tol].next = head[v]; head[v] = tol++;
    }

    void dfs(int u, int d)
    {
        dep[cnt] = d;
        ver[cnt] = u;

        if(ind[u] == -1)
        {
            ind[u] = cnt;
        }
        cnt++;
        if(head[u] == -1)
            return;
        for(int i = head[u]; i != -1; i = e[i].next)
        {
            int temp = e[i].to;
            if(fa[temp] == -1)
            {
                fa[temp] = u;
                efa[temp] = i;
                dis[temp] = e[i].dist + dis[u];
                dfs(temp, d + 1);
                dep[cnt] = d;
                ver[cnt++] = u;
            }
        }
    }

    void work() //输入边后处理
    {
        dfs(root, 0);
        STlen = (2 * n - 2) / BASE + 1;

        for(int i = 0; i < 2 * n - 1; i ++)
        {
            if(i % BASE == 0) //块首
            {
                msn[0][i / BASE] = i; //块内最值地址
                STsta[i / BASE] = 0; //状态序列
            }
            else
            {
                if(dep[i] < dep[msn[0][i / BASE]])
                    msn[0][i / BASE] = i; // 更新地址
                if(dep[i] > dep[i - 1])
                    STsta[i / BASE] |= 1 << (i % BASE - 1); //大于差分序列为1
            }
        }

        for(int j = 1; (1 << j) <= STlen; j ++) // j < lgn
        {
            for(int i = 0; i + (1 << j) - 1 < STlen; i ++)
            {
                int b1 = msn[j - 1][i], b2 = msn[j - 1][i + (1 << (j - 1))];
                msn[j][i] = dep[b1] < dep[b2]? b1 : b2;
            }
        }
    }

    int querymin(int L, int R)  //返回位置
    {
        int idl = L / BASE, idr = R / BASE;
        if(idl == idr)
            return idl * BASE + sta[STsta[idl]][L % BASE][R % BASE];
        else
        {
            int b1 = idl * BASE + sta[STsta[idl]][L % BASE][BASE - 1];
            int b2 = idr * BASE + sta[STsta[idr]][0][R % BASE];
            int buf = dep[b1] < dep[b2]? b1 : b2;

            if(idr - idl - 1)
            {
                int c = lg[idr - idl - 1];
                int b1 = msn[c][idl + 1];
                int b2 = msn[c][idr - (1 << c)];
                int b = dep[b1] < dep[b2]? b1 : b2;
                return dep[buf] < dep[b]? buf : b;
            }
            return buf;
        }
    }

    int lcadep(int x, int y) //祖先所在深度
    {
        int L = ind[x];
        int R = ind[y];
        if(L > R)
            swap(L, R);
        return dep[querymin(L, R)];

    }

    int lcaver(int x, int y) //祖先顶点
    {
        int L = ind[x];
        int R = ind[y];
        if(L > R)
            swap(L, R);
        return ver[querymin(L, R)];
    }

/*
    void pri()
    {
        cout<
///*
    int bin[23];
    int maxx[23][maxn];
    int minx[23][maxn];
    int lgg[maxn];
    void rangework()
    {
        ST_buildmin(n);
        ST_buildmax(n);
    }

    int rangever(int x, int y) //区间lca查询
    {
        int L = ST_min(x, y);
        int R = ST_max(x, y);
        if(L > R)
            swap(L, R);
        return ver[querymin(L, R)];

    }
    void ST_buildmin(int n)
    {
        lgg[0] = -1;
        for(int i = 1; i <= n; i++)
            lgg[i] = lgg[i / 2] + 1;
        bin[0] = 1;
        for(int j = 1; j < 23; j++)
            bin[j] = bin[j - 1] * 2;
        for(int i = 1; i <= n; i++)
        {
            minx[0][i] = ind[i];
        }
        for(int i = 1; i <= lgg[n]; i++)
            for(int j = 1; j + bin[i] - 1 <= n; j++)
                    minx[i][j] = min(minx[i - 1][j], minx[i - 1][j + bin[i - 1]]);
    }
    void ST_buildmax(int n)
    {
        for(int i = 1; i <= n; i++)
        {
            maxx[0][i] = ind[i];
        }
        for(int i = 1; i <= lgg[n]; i++)
            for(int j = 1; j + bin[i] - 1 <= n; j++)
                    maxx[i][j] = max(maxx[i - 1][j], maxx[i - 1][j + bin[i - 1]]);
    }

    int ST_min(int x, int y)
    {
        if(x > y)
            swap(x, y);
        int temp = lgg[y - x + 1];

        return min(minx[temp][x], minx[temp][y - bin[temp] + 1]);
    }
    int ST_max(int x, int y)
    {
        if(x > y)
            swap(x, y);
        int temp = lgg[y - x + 1];
        return max(maxx[temp][x], maxx[temp][y - bin[temp] + 1]);
    }
//*/
}lca;

int n, m;
int main()
{
    lca.pre(); //预处理
    while(scanf("%d", &n) == 1)
    {
        int u, v, d;
        lca.init(n, 1); //重置
        for(int i = 0; i < n - 1; i++)
        {
            scanf("%d%d", &u, &v);
            lca.addedge(u, v, 0); //添加边
        }
        lca.work(); //lca处理
        lca.rangework(); //区域lca处理
        //lca.lcadep(x, y); //祖先深度
        //lca.lcaver(x, y); //祖先结点
        scanf("%d", &m);
        int x, y;
        for(int i = 1; i <=  m; i++)
        {
            scanf("%d%d", &x, &y);
            printf("%d\n", lca.rangever(x, y));
        }
    }

}

你可能感兴趣的:(随笔)