O ( m l o g n ) O(mlogn) O(mlogn),n为节点数量,m为询问次数,lca是一种在线处理询问的算法
自己也是自己的祖先
倍增:
f a ( i , j ) fa(i, j) fa(i,j)表示从i开始,向上走 2 j 2^j 2j步走到的点
j = 0,走到父节点
j > 0,分两步走,先走到 2 j − 1 2^{j-1} 2j−1步再走 2 j − 1 2^{j-1} 2j−1步,那么一共就会走 2 j 2^j 2j步, f a ( i , j ) = f a ( f a ( i , j − 1 ) , j − 1 ) fa(i, j) = fa(fa(i, j-1), j-1) fa(i,j)=fa(fa(i,j−1),j−1)
d e p t h ( i ) depth(i) depth(i)表示层数
第一步:基于二进制的思想,x和y之间的层数差距为 d e p t h ( x ) − d e p t h ( y ) depth(x) - depth(y) depth(x)−depth(y),假设y的层数小于x的层数,此时x要往上跳
若要跳k层,那么根据k的二进制表示将k拆分成多个2的幂相加,由于我们已经预处理了 f ( i , j ) f(i, j) f(i,j),所以根据 f ( i , j ) f(i, j) f(i,j)的值往上跳即可
当 f ( x , k ) = f ( y , k ) f(x, k) = f(y, k) f(x,k)=f(y,k)时,即x往上跳k步和y往上跳k步后,位于同一个位置,此时找到了一个公共祖先,但不是最近公共祖先,所以这里要减小k的值,直到 f ( x , k ) ≠ f ( y , k ) f(x, k) ≠ f(y, k) f(x,k)=f(y,k),此时才找到了最近公共祖先
规定 d e p t h ( 0 ) = 0 depth(0) = 0 depth(0)=0,0号点为哨兵,位于第0层,根节点位于第一层
向上跳 2 k 2^k 2k步后,若跳出了这颗树,那么 f ( i , k ) = 0 f(i, k) = 0 f(i,k)=0
预处理depth和fa的板子:
void bfs(int root)
{
memset(dis, 0x3f, sizeof(dis));
q[tt ++ ] = root;
depth[root] = 1, depth[0] = 0;
while (tt >= hh)
{
int x = q[hh ++ ];
for (int i = h[x]; i != -1; i = ne[i])
{
int y = e[i];
if (depth[y] > depth[x] + 1)
{
depth[y] = depth[x] + 1;
q[tt ++ ] = y;
fa[y][0] = x;
for (int k = 1; k <= c; ++ k )
fa[y][k] = fa[fa[y][k - 1]][k - 1];
}
}
}
}
lca板子:
int lca(int x, int y)
{
if (depth[x] < depth[y]) swap(x, y);
for (int k = c; k >= 0; -- k )
if (depth[fa[x][k]] >= depth[y])
x = fa[x][k];
if (x == y) return y;
for (int k = c; k >= 0; -- k )
if (fa[x][k] != fa[y][k])
{
x = fa[x][k];
y = fa[x][k];
}
return fa[x][0];
}
O ( n + m ) O(n + m) O(n+m) ,n为节点数量,m为询问次数
tarjan是一种离线处理询问的算法,是向上标记法的优化
在深度优先遍历时,将所有点分成三大类
将已经遍历且回溯的点标记成2,正在遍历的点标记成1,未遍历的点标记成0
对于正在回溯的点,需要处理所有有关该点的询问信息:若询问的另外一个点已经遍历过(回溯完成),那么该点将被分到一个集合中,集合的代表点就是两点的最近公共祖先
比如上图,当前正在遍历x这个点,已经遍历完的点为绿色部分,这些点所属集合的代表点位于红色分支上
模板:
// 用dfs维护dis数组
void dfs(int x, int fa)
{
for (int i = h[x]; i != -1; i = ne[i])
{
int y = e[i];
if (y != fa)
{
dis[y] = dis[x] + w[i];
dfs(y, x);
}
}
}
// tarjan处理询问
void tarjan(int x)
{
st[x] = 1; // 当前正在遍历该点
for (int i = h[x]; i != -1; i = ne[i])
{
int y = e[i];
if (!st[y]) // 当前为遍历ydian
{
tarjan(y);
p[y] = u;
}
}
for (auto item : query[x]) // 处理有关当前节点的询问
{
int y = item.first, id = item.first; // id为该询问的唯一编号
if (st[y] == 2) // 询问的另一点已经遍历过
{
int anc = find(y); // 找到集合的代表点,公共祖先
res[id] = dis[x] + dis[y] - 2 * dis[anc]; // 根据id将答案存储到数组中
}
}
st[x] = 2; // 当前节点遍历完
}
由于树的层数最大有40000层,所以一个节点最多向上跳 l o g 400000 log400000 log400000层,大概是一个大于 2 15 2^{15} 215小于 2 16 2^{16} 216的数,所以最多跳 2 16 − 1 2^{16} - 1 216−1层,二进制位取15即可
#include
#include
using namespace std;
const int N = 4e4 + 10, M = 2 * N;
int h[N], e[M], ne[M], idx;
int q[N], hh, tt = -1;
int depth[N], fa[N][16]; // 最多跳2^16-1
int n;
void add(int x, int y)
{
e[idx] = y, ne[idx] = h[x], h[x] = idx ++ ;
}
void bfs(int root)
{
q[++ tt] = root;
memset(depth, 0x3f, sizeof(depth));
depth[0] = 0, depth[root] = 1;
while (tt >= hh)
{
int x = q[hh ++ ];
for (int i = h[x]; i != -1; i = ne[i])
{
int y = e[i];
if (depth[y] > depth[x] + 1)
{
depth[y] = depth[x] + 1;
fa[y][0] = x;
q[++ tt ] = y;
for (int k = 1; k <= 15; ++ k )
fa[y][k] = fa[fa[y][k - 1]][k - 1];
}
}
}
}
int lca(int x, int y)
{
if (depth[x] < depth[y]) swap(x, y);
for (int k = 15; k >= 0; -- k )
if (depth[fa[x][k]] >= depth[y])
x = fa[x][k];
if (x == y) return y;
for (int k = 15; k >= 0; -- k )
if (fa[x][k] != fa[y][k])
{
x = fa[x][k];
y = fa[y][k];
}
return fa[x][0];
}
int main()
{
memset(h, -1, sizeof(h));
scanf("%d", &n);
int root;
int a, b;
for (int i = 0; i < n; ++ i )
{
scanf("%d%d", &a, &b);
if (b == -1) root = a;
else add(a, b), add(b, a);
}
bfs(root);
int m;
scanf("%d", &m);
for (int i = 0; i < m; ++ i )
{
scanf("%d%d", &a, &b);
int t = lca(a, b);
if (t == a) puts("1");
else if (t == b) puts("2");
else puts("0");
}
return 0;
}
与上题一样,题目要处理很多询问,可以用lca或者离线tarjan解决
树的最短路问题可以从公共祖先的角度考虑,假设x和y的公共祖先为t, d i s ( t ) dis(t) dis(t)为根节点到t的距离,那么x和y之间的最短距离就是 d i s ( x ) + d i s ( y ) − 2 ∗ d i s ( t ) dis(x) + dis(y) - 2 * dis(t) dis(x)+dis(y)−2∗dis(t)
题目没有给定根节点,我们任意指定一个点为根节点即可
lca解法:
#include
#include
using namespace std;
const int N = 1e4 + 10, M = 2 * N;
int h[N], e[M], ne[M], w[M], idx;
int depth[N], fa[N][16];
int q[N], hh, tt = -1;
int dis[N];
void add(int x, int y, int d)
{
e[idx] = y, ne[idx] = h[x], w[idx] = d, h[x] = idx ++ ;
}
void dfs(int root)
{
memset(depth, 0x3f, sizeof(depth));
depth[0] = 0, depth[root] = 1;
q[++ tt ] = root;
while (tt >= hh)
{
int x = q[hh ++ ];
for (int i = h[x]; i != -1; i = ne[i])
{
int y = e[i];
if (depth[y] > depth[x] + 1)
{
depth[y] = depth[x] + 1;
dis[y] = dis[x] + w[i];
fa[y][0] = x;
q[++ tt ] = y;
for (int k = 1; k <= 15; ++ k )
fa[y][k] = fa[fa[y][k - 1]][k - 1];
}
}
}
}
int lca(int x, int y)
{
if (depth[x] < depth[y]) swap(x, y);
for (int k = 15; k >= 0; -- k )
if (depth[fa[x][k]] >= depth[y])
x = fa[x][k];
if (x == y) return y;
for (int k = 15; k >= 0; -- k )
if (fa[x][k] != fa[y][k])
{
x = fa[x][k];
y = fa[y][k];
}
return fa[x][0];
}
int main()
{
memset(h, -1, sizeof(h));
int n, m;
scanf("%d%d", &n, &m);
int x, y, d;
for (int i = 0; i < n - 1; ++ i )
{
scanf("%d%d%d", &x, &y, &d);
add(x, y, d), add(y, x, d);
}
dfs(1);
for (int i = 0; i < m; ++ i )
{
scanf("%d%d", &x, &y);
int t = lca(x, y);
printf("%d\n", dis[x] + dis[y] - 2 * dis[t]);
}
return 0;
}
tarjan解法:
#include
#include
#include
using namespace std;
const int N = 1e4 + 10, M = 2 * N;
typedef pair<int, int> PII;
vector<PII> query[N];
int h[N], e[M], ne[M], w[M], idx;
int res[M], dis[N], st[N], p[N];
int n, m;
void add(int x, int y, int d)
{
e[idx] = y, ne[idx] = h[x], w[idx] = d, h[x] = idx ++ ;
}
void dfs(int x, int fa)
{
for (int i = h[x]; i != -1; i = ne[i])
{
int y = e[i];
if (y != fa)
{
dis[y] = dis[x] + w[i];
dfs(y, x);
}
}
}
int find(int x)
{
if (p[x] != x) p[x] = find(p[x]);
return p[x];
}
void tarjan(int x)
{
st[x] = 1;
for (int i = h[x]; i != -1; i = ne[i])
{
int y = e[i];
if (!st[y])
{
tarjan(y);
p[y] = x;
}
}
for (auto item : query[x])
{
int y = item.first, id = item.second;
if (st[y] == 2)
{
int anc = find(y);
res[id] = dis[x] + dis[y] - 2 * dis[anc];
}
}
st[x] = 2;
}
int main()
{
memset(h, -1, sizeof(h));
scanf("%d%d", &n, &m);
int x, y, d;
for (int i = 0; i < n - 1; ++ i )
{
scanf("%d%d%d", &x, &y, &d);
add(x, y, d), add(y, x, d);
}
for (int i = 0; i < m; ++ i )
{
scanf("%d%d", &x, &y);
query[x].push_back({y, i}); // 将查询的另一个点与查询编号保存
query[y].push_back({x, i});
}
dfs(1, -1);
for (int i = 1; i <= n; ++ i ) p[i] = i;
tarjan(1);
for (int i = 0; i < m; ++ i ) printf("%d\n", res[i]);
return 0;
}
debug:维护的query信息要不同地插入两次
query[x].push_back({y, i});
query[y].push_back({x, i});
因为当前询问有关x的信息时,y可能没有遍历完,但是询问y有关的信息时,x是遍历完的
d 1 ( i , j ) d1(i, j) d1(i,j)表示从i往上跳 2 j 2^j 2j层后,路径上的最大边权
d 2 ( i , j ) d2(i, j) d2(i,j)表示从i往上跳 2 j 2^j 2j层后,路径上的次大边权
跳 2 j 2^j 2j步是一个类似分治的过程,分成两部分跳,这两部分依旧能分成两部分跳
路径上的最大值为每一段的最大值取max,次大值为所有子路径的最大值和次大值中的第二大,每次遍历子线段时维护信息即可
if (d > d1) d2 = d1, d1 = d;
else if (d != d1 && d > d2) d2 = d;
#include
#include
#include
using namespace std;
typedef long long LL;
const int N = 1e5 + 10, M = 3e5 + 10, INF = 0x3f3f3f3f;
struct Edge
{
int x, y, w;
bool f;
bool operator<(const Edge& e) const
{
return w < e.w;
}
}edges[M];
int h[N], e[M], ne[M], w[M], idx;
int p[N], depth[N], fa[N][17];
int q[N], hh, tt = -1;
int d1[N][17], d2[N][17];
int d[M];
int n, m;
void add(int x, int y, int d)
{
e[idx] = y, ne[idx] = h[x], w[idx] = d, h[x] = idx ++ ;
}
int find(int x)
{
if (x != p[x]) p[x] = find(p[x]);
return p[x];
}
LL kruskal()
{
LL res = 0;
sort(edges, edges + m);
for (int i = 1; i <= n; ++ i ) p[i] = i;
for (int i = 0; i < m; ++ i )
{
auto t = edges[i];
int x = t.x, y = t.y, w = t.w;
int px = find(x), py = find(y);
if (px != py)
{
res += w;
p[px] = py;
edges[i].f = true;
add(x, y, w), add(y, x, w);
}
}
return res;
}
void bfs()
{
q[++ tt ] = 1;
memset(depth, 0x3f, sizeof(depth));
depth[0] = 0, depth[1] = 1;
while (tt >= hh)
{
int x = q[hh ++ ];
for (int i = h[x]; i != -1; i = ne[i])
{
int y = e[i];
if (depth[y] > depth[x] + 1)
{
depth[y] = depth[x] + 1;
fa[y][0] = x;
d1[y][0] = w[i], d2[y][0] = -INF;
q[++ tt ] = y;
for (int k = 1; k <= 16; ++ k )
{
int mid = fa[y][k - 1];
fa[y][k] = fa[mid][k - 1];
d1[y][k] = d2[y][k] = -INF;
int a[4] = { d1[mid][k - 1], d2[mid][k - 1], d1[y][k - 1], d2[y][k - 1] };
for (int i = 0; i < 4; ++ i )
{
if (a[i] > d1[y][k]) d2[y][k] = d1[y][k], d1[y][k] = a[i];
else if (a[i] != d1[y][k] && a[i] > d2[y][k]) d2[y][k] = a[i];
}
}
}
}
}
}
int lca(int x, int y, int w)
{
int cnt = 0;
if (depth[x] < depth[y]) swap(x, y);
for (int k = 16; k >= 0; -- k )
if (depth[fa[x][k]] >= depth[y])
{
d[cnt ++ ] = d1[x][k];
d[cnt ++ ] = d2[x][k];
x = fa[x][k];
}
if (x != y)
{
for (int k = 16; k >= 0; -- k )
{
if (fa[x][k] != fa[y][k])
{
d[cnt ++ ] = d1[x][k], d[cnt ++ ] = d2[x][k];
d[cnt ++ ] = d1[y][k], d[cnt ++ ] = d2[y][k];
x = fa[x][k], y = fa[y][k];
}
d[cnt ++ ] = d1[x][0], d[cnt ++ ] = d2[x][0];
d[cnt ++ ] = d1[y][0], d[cnt ++ ] = d2[y][0];
}
}
int dmax1 = -INF, dmax2 = -INF;
for (int i = 0; i < cnt; ++ i )
{
if (d[i] > dmax1) dmax2 = dmax1, dmax1 = d[i];
else if (d[i] != dmax1 && d[i] > dmax2) dmax2 = d[i];
}
if (w > dmax1) return w - dmax1;
return w - dmax2;
}
int main()
{
memset(h, -1, sizeof(h));
scanf("%d%d", &n, &m);
for (int i = 0; i < m; ++ i )
scanf("%d%d%d", &edges[i].x, &edges[i].y, &edges[i].w);
LL sum = kruskal(); // 最小生成树的权值
bfs();
LL res = 1e19;
for (int i = 0; i < m; ++ i )
if (!edges[i].f)
res = min(res, sum + lca(edges[i].x, edges[i].y, edges[i].w));
printf("%lld\n", res);
return 0;
}
树上差分,将x到y的最短路径上所有的边加上c,若p为x和y的公共祖先,那么
d(x) += c, d(y) += c, d(p) -= 2c
如何计算某条边的权值?以这条边的子节点为根的子树中,所有边的权值相加为这条边的权值
在一颗树中,删除任意一条边,那么这颗树将被切成两个连通块
若向树中再添加一条边,那么这条非树边和树边就一定构成环,要向将此时的“树”切成两个连通块,就要删除环中的任意一条树边与这条非树边
题目限制只能先切树边,再切非树边,一共两次,两次过后还没有切成两个连通块,说明这个方案行不通
当切除树边,不用再切除非树边就得到两个连通块时,由于题目限制,还需要切除一条非树边,假设非树边有m条,那么此时可以选择m条边中的任意一条切除,此时的方案数为m
若切除树边后,还要再切除一条非树边,才能得到两个连通块时,此时的方案数为1,只能切除这条环中的非树边
若切除树边后,还要再切除大于一条的非树边,此时无法再切除,方案数为0
现在的问题是,如何知道切除某条树边后,还需要再切除几条非树边?
假设现在已经用树边建立了一棵树,此时再添加非树边将构成环,将环中的所有树边权值加1,假设初始权值为0,此时可以使用树上差分
然后遍历所有的树边,根据边权计算方案数
#include
#include
using namespace std;
const int N = 1e5 + 10, M = 2e5 + 10;
int h[N], e[M], ne[M], idx;
int depth[N], fa[N][17];
int q[N], hh, tt = -1;
int d[N];
int n, m, ans;
void add(int x ,int y)
{
e[idx] = y, ne[idx] = h[x], h[x] = idx ++ ;
}
void bfs()
{
memset(depth, 0x3f, sizeof(depth));
depth[0] = 0, depth[1] = 1;
q[++ tt ] = 1;
while (tt >= hh)
{
int x = q[hh ++ ];
for (int i = h[x]; i != -1; i = ne[i])
{
int y = e[i];
if (depth[y] > depth[x] + 1)
{
depth[y] = depth[x] + 1;
fa[y][0] = x;
q[++ tt ] = y;
for (int k = 1; k <= 16; ++ k )
fa[y][k] = fa[fa[y][k - 1]][k - 1];
}
}
}
}
int lca(int x, int y)
{
if (depth[x] < depth[y]) swap(x, y);
for (int k = 16; k >= 0; -- k )
if (depth[fa[x][k]] >= depth[y])
x = fa[x][k];
if (x == y) return x;
for (int k = 16; k >= 0; -- k )
if (fa[x][k] != fa[y][k])
{
x = fa[x][k];
y = fa[y][k];
}
return fa[x][0];
}
int dfs(int x, int f)
{
int res = d[x];
for (int i = h[x]; i != -1; i = ne[i])
{
int y = e[i];
if (f != y)
{
int t = dfs(y, x);
if (t == 0) ans += m;
else if (t == 1) ans ++;
res += t;
}
}
return res;
}
int main()
{
memset(h, -1, sizeof(h));
scanf("%d%d", &n, &m);
int x, y;
for (int i = 0; i < n - 1; ++ i )
{
scanf("%d%d", &x, &y);
add(x, y), add(y, x);
}
bfs();
int res = 0;
for (int i = 0; i < m; ++ i)
{
scanf("%d%d", &x, &y);
int p = lca(x, y);
d[x] ++, d[y] ++, d[p] -= 2;
}
dfs(1, -1);
printf("%d\n", ans);
return 0;
}