Codeforces-983E (Round483 Div1) NN country 树上倍增+树状数组

大意:给定一棵n个节点的树和m条线路,每条线路覆盖树上两个端点之间的简单路径,可在该路径上任意两点间移动;给出q个询问,每次问从u到v最少需要乘坐几条线路,无法到达时输出-1。
解法:贪心地求出u和v各自跳到lca最少需要跳几次,然后判断最后一步能否由同一条线路同时完成。跳跃用倍增实现,单次查询复杂度为O(logN);最后的特判离线用DFS加树状数组完成,单次同样为O(logN)。

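对于每一个节点v,其子树在DFS序上对应的区间为[dfn[v], dfn[v]+sz[v]-1](即代码中的[dfn[v], dfed[v]])。对于需要特判的一组(u,v),离线地在DFS中处理:进入u时先记下v的子树区间内已有的计数;然后在遍历u的整棵子树时,对其中每个节点x,把所有以x为一端的线路的另一端在树状数组中+1,相当于有一条线路从u的子树“戳”到某一个区间(即某一个节点的子树)中;回溯离开u时再统计一次。如果v的子树区间内的计数增加了,说明存在一条线路同时连接u的子树和v的子树,最后两步可以由同一条线路完成,答案减1。哪些(u,v)需要特判呢?程序求出的是u和v各自跳到“再跳一次就到达lca”的位置所需的次数a和b,因此如果u或v本身就是lca,就不需要特判。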

 


// The header names were lost when this listing was scraped (every `<...>`
// was stripped); these are the headers the code below actually uses.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>
using namespace std;
// Left/right child index of heap-numbered node `n`. Both implicitly use
// whatever variable named `n` is in scope; they are not used in the visible
// code. Parenthesized so they expand safely inside larger expressions
// (unparenthesized, `lc + 1` would parse as `n << (1 + 1)`).
#define lc (n << 1)
#define rc (n << 1 | 1)
typedef long long ll;
typedef unsigned long long ull;
const int maxint = ~0U >> 1;  // UINT_MAX >> 1, i.e. INT_MAX on typical platforms


const int maxn = 2e5 + 7;     // capacity of the per-node, per-edge and per-query arrays
const int INF = 0x3f3f3f3f;
const int logN = 20;          // binary-lifting levels 0..19 (2^19 exceeds maxn)
// One parent -> child link of the rooted tree. Children of a node form a
// singly linked list: head[parent] is the first edge index, `next` chains
// siblings, and -1 terminates the list.
struct edge {
    int next;  // next edge out of the same parent, or -1
    int to;    // child endpoint
} e[maxn];


int n, m, q;  // node count, route count, query count
// Euler-tour entry time, largest entry time in the subtree, LCA ancestor
// table (-1 above the root), depth, and the DFS timer.
int dfn[maxn], dfed[maxn], anc[maxn][logN], dep[maxn], tot;
int head[maxn], cnt;   // child-list heads (-1 = none) and next free edge slot
int top[maxn];         // shallowest node reachable from each node with one route
int jump[maxn][logN];  // jump[v][i]: node reached from v after 2^i greedy hops
bool bad[maxn];        // query i is unreachable (answer -1)
int ans[maxn], tmpcnt[maxn];  // per-query answer and Fenwick-count scratch
// The template arguments were lost when this listing was scraped; these
// element types are what the code's usage requires.
std::vector<int> wayed[maxn];                   // wayed[x]: other endpoint of every route ending at x
std::vector<std::pair<int, int> > query[maxn];  // query[u]: (v, query id) pairs checked by solve()
inline void add(int x, int y) 
{
e[cnt].to = y;
e[cnt].next = head[x];
head[x] = cnt++;
}
void dfs1(int n, int fa = -1)
{
dfn[n] = ++tot;
dep[n] = (fa == -1 ? 0 : dep[fa] + 1);
anc[n][0] = fa;
for (int i = 1;i < logN;i++)
{
if (anc[n][i - 1] == -1)
anc[n][i] = -1;
else
anc[n][i] = anc[anc[n][i - 1]][i - 1];
}
for (int i = head[n];~i;i = e[i].next)
{
int v = e[i].to;
dfs1(v, n);
}
dfed[n] = tot;
}
// True if u is an ancestor of v; a node counts as its own ancestor.
// u == -1 ("above the root") is an ancestor of everything, while v == -1 is
// not a descendant of any real node.
bool is_ancestor(int u, int v)
{
    if (u == -1) return true;
    if (v == -1) return false;
    // v must lie inside u's Euler-tour interval [dfn[u], dfed[u]].
    return !(dfn[v] < dfn[u] || dfed[v] > dfed[u]);
}
// Lowest common ancestor by binary lifting: raise u to the highest ancestor
// that is still not an ancestor of v, then take one more step up.
int lca(int u, int v)
{
    if (is_ancestor(u, v)) return u;
    if (is_ancestor(v, u)) return v;
    int cur = u;
    for (int k = logN - 1; k >= 0; --k)
    {
        const int up = anc[cur][k];
        if (!is_ancestor(up, v)) cur = up;
    }
    return anc[cur][0];
}
int calc_top(int n)//求一次最远跳到的点top
{
for (int i = head[n];~i;i = e[i].next)
{
top[n] = min(top[n], calc_top(e[i].to));
}
return top[n];
}
void dfs2(int n)//倍增往上跳
{
jump[n][0] = top[n];
for (int i = 1;i < logN;i++)
jump[n][i] = jump[jump[n][i - 1]][i - 1];
for (int i = head[n];~i;i = e[i].next)
dfs2(e[i].to);
}
bool reachable(int u, int v)
{
int _lca = lca(u, v);
for (int i = logN - 1;i >= 0;i--)
{
if (!is_ancestor(jump[u][i], _lca))
u = jump[u][i];
if (!is_ancestor(jump[v][i], _lca))
v = jump[v][i];
}
u = jump[u][0], v = jump[v][0];
return is_ancestor(u, _lca) && is_ancestor(v, _lca);
}
// Lifts u (updated through the reference) with the largest greedy route hops
// whose landing node is still not an ancestor of v, and returns how many
// single-route hops that took.
int calc_cnt(int& u, int v)
{
    int hops = 0;
    for (int k = logN - 1; k >= 0; --k)
    {
        const int next = jump[u][k];
        if (is_ancestor(next, v)) continue;
        u = next;
        hops += 1 << k;
    }
    return hops;
}


int a[maxn] = {};  // Fenwick-tree counts, 1-indexed by Euler-tour time dfn
// Lowest set bit of x (e.g. 12 -> 4); the step size of Fenwick-tree walks.
inline int lowbit(int x)
{
    return x & ~(x - 1);
}
// Fenwick-tree point update: adds one at position x (positions 1..n).
void add(int x)
{
    for (; x <= n; x += lowbit(x))
        ++a[x];
}
// Fenwick-tree prefix sum over positions 1..x; 0 when x <= 0.
int sum(int x)
{
    int total = 0;
    for (; x > 0; x -= lowbit(x))
        total += a[x];
    return total;
}
void solve(int n)
{
for (auto qry : query[n])
{
int v = qry.first;
tmpcnt[qry.second] += sum(dfed[v]) - sum(dfn[v] - 1);
}
for (auto ed : wayed[n])
add(dfn[ed]);
for (int i = head[n];~i;i = e[i].next)
solve(e[i].to);
for (auto qry : query[n])
{
int v = qry.first, id = qry.second;
tmpcnt[id] -= sum(dfed[v]) - sum(dfn[v] - 1);
if (tmpcnt[id] <= -1)
ans[id]--;
}
}
int main()
{
int u, v, _lca, tmp;
memset(head, -1, sizeof head);
memset(anc, -1, sizeof anc);
memset(jump, -1, sizeof jump);
scanf("%d", &n);
for (int i = 2;i <= n;i++)
{
scanf("%d", &tmp);
add(tmp, i);
}
dfs1(1);//precalc the lca
for (int i = 1;i <= n;i++)
top[i] = i;
scanf("%d", &m);
for (int i = 0;i < m;i++)
{
scanf("%d%d", &u, &v);
wayed[u].push_back(v);
wayed[v].push_back(u);
_lca = lca(u, v);
if (dep[top[v]] > dep[_lca])
top[v] = _lca;
if (dep[top[u]] > dep[_lca])
top[u] = _lca;
}
calc_top(1);
dfs2(1);

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;
#define lc n << 1
#define rc n << 1 | 1
typedef long long ll;
typedef unsigned long long ull;
const int maxint = ~0U >> 1;


const int maxn = 2e5 + 7;
const int INF = 0x3f3f3f3f;
const int logN = 20;
struct edge {
int next, to;
}e[maxn];


int n, m, q;
int dfn[maxn], dfed[maxn], anc[maxn][logN], dep[maxn], tot;//lca
int head[maxn], cnt;//road
int top[maxn];//the furthest ancestor
int jump[maxn][logN];
bool bad[maxn];
int ans[maxn],tmpcnt[maxn];
vectorwayed[maxn];
vector >query[maxn];
inline void add(int x, int y) 
{
    e[cnt].to = y;
    e[cnt].next = head[x];
    head[x] = cnt++;
}
void dfs1(int n, int fa = -1)
{
    dfn[n] = ++tot;
    dep[n] = (fa == -1 ? 0 : dep[fa] + 1);
    anc[n][0] = fa;
    for (int i = 1;i < logN;i++)
    {
        if (anc[n][i - 1] == -1)
            anc[n][i] = -1;
        else
            anc[n][i] = anc[anc[n][i - 1]][i - 1];
    }
    for (int i = head[n];~i;i = e[i].next)
    {
        int v = e[i].to;
        dfs1(v, n);
       }
    dfed[n] = tot;
}
bool is_ancestor(int u, int v)//if u is the ancestor of v
{
    if (u == -1)return true;
    if (v == -1)return false;
    return dfn[u] <= dfn[v] && dfed[u] >= dfed[v];
}
int lca(int u, int v)
{
    if (is_ancestor(u, v))return u;
    if (is_ancestor(v, u))return v;
    for (int i = logN - 1;i >= 0;i--)
    {
        if (!is_ancestor(anc[u][i], v))
        u = anc[u][i];
    }
    return anc[u][0];
}
int calc_top(int n)//求一次最远跳到的点top
{
    for (int i = head[n];~i;i = e[i].next)
    {
        top[n] = min(top[n], calc_top(e[i].to));
    }
    return top[n];
}
void dfs2(int n)//倍增往上跳
{
    jump[n][0] = top[n];
    for (int i = 1;i < logN;i++)
        jump[n][i] = jump[jump[n][i - 1]][i - 1];
    for (int i = head[n];~i;i = e[i].next)
        dfs2(e[i].to);
}
bool reachable(int u, int v)
{
    int _lca = lca(u, v);
    for (int i = logN - 1;i >= 0;i--)
    {
        if (!is_ancestor(jump[u][i], _lca))
            u = jump[u][i];
        if (!is_ancestor(jump[v][i], _lca))
            v = jump[v][i];
    }
    u = jump[u][0], v = jump[v][0];
    return is_ancestor(u, _lca) && is_ancestor(v, _lca);
}
int calc_cnt(int& u, int v)
{
    int cnt = 0;
    for (int i = logN - 1;i >= 0;i--)
    {
        if (!is_ancestor(jump[u][i], v))
        {
            u = jump[u][i];
            cnt += (1 << i);
        }
    }
    return cnt;
}


int a[maxn];
inline int lowbit(int x){return x & -x;}
void add(int x)
{
    while (x <= n)
    {
        a[x]++;
        x += lowbit(x);
    }
}
int sum(int x)
{
    int ret = 0;
    while (x > 0)
    {
        ret += a[x];
        x -= lowbit(x);
    }
    return ret;
}
void solve(int n)
{
    for (auto qry : query[n])
    {
        int v = qry.first;
        tmpcnt[qry.second] += sum(dfed[v]) - sum(dfn[v] - 1);
    }
    for (auto ed : wayed[n])
    add(dfn[ed]);
    for (int i = head[n];~i;i = e[i].next)
        solve(e[i].to);
    for (auto qry : query[n])
    {
        int v = qry.first, id = qry.second;
        tmpcnt[id] -= sum(dfed[v]) - sum(dfn[v] - 1);
        if (tmpcnt[id] <= -1)
            ans[id]--;
    }
   }
int main()
{
    int u, v, _lca, tmp;
    memset(head, -1, sizeof head);
    memset(anc, -1, sizeof anc);
    memset(jump, -1, sizeof jump);
    scanf("%d", &n);
    for (int i = 2;i <= n;i++)
    {
        scanf("%d", &tmp);
        add(tmp, i);
    }
    dfs1(1);//precalc the lca
    for (int i = 1;i <= n;i++)
        top[i] = i;
    scanf("%d", &m);
    for (int i = 0;i < m;i++)
    {
        scanf("%d%d", &u, &v);
        wayed[u].push_back(v);
        wayed[v].push_back(u);
        _lca = lca(u, v);
        if (dep[top[v]] > dep[_lca])
            top[v] = _lca;
        if (dep[top[u]] > dep[_lca])
            top[u] = _lca;
    }
    calc_top(1);
    dfs2(1);
    scanf("%d", &q);
    for(int i = 0;i < q;i++)
    {
        scanf("%d%d", &u, &v);
        if (!reachable(u, v))bad[i] = true;
        _lca = lca(u, v);
        ans[i] = calc_cnt(u, _lca) + calc_cnt(v, _lca);
        if (u != _lca)ans[i]++;
        if (v != _lca)ans[i]++;
        if (u != _lca && v != _lca)query[u].push_back({ v,i });
    }
    solve(1);
    for (int i = 0;i < q;i++)
    {
        printf("%d\n", bad[i] ? -1 : ans[i]);
    }
    return 0;
}


scanf("%d", &q);
for(int i = 0;i < q;i++)
{
scanf("%d%d", &u, &v);
if (!reachable(u, v))bad[i] = true;
_lca = lca(u, v);
ans[i] = calc_cnt(u, _lca) + calc_cnt(v, _lca);
if (u != _lca)ans[i]++;
if (v != _lca)ans[i]++;
if (u != _lca && v != _lca)query[u].push_back({ v,i });
}
solve(1);
for (int i = 0;i < q;i++)
{
printf("%d\n", bad[i] ? -1 : ans[i]);
}
return 0;
}

 

你可能感兴趣的:(ACM学习)