突然发现剖分树可以在log(n)的时间里求出lca,于是又删了几十行的代码。
#include <iostream> #include <stdio.h> #include <vector> #include <algorithm> #include <ctime> using namespace std; const int N = 20005; const int M = 100005; const int QN = M; const int INF = 0X7FFFFFFF; typedef int vType; typedef pair<int, int> pii; #define mkpii make_pair<int, int> struct e{ int v; e* nxt; }es[N<<1], *fir[N]; struct node{ int ls, rs; //左右儿子的下标,为-1表示空 int l, r; //区间的左右标号 //数据域 int id; //如果这个是叶子节点,id表示代表原树中的节点的标号 vType Min; //Min为这一整段插入的一个最小值 int mid() { return (l + r) >> 1; } }nodes[N<<1]; struct se{ pii e; int len; }ses[M<<1], lea[M<<1]; int n, en, qn, m; vector<pii> qlca[N]; vector<se> nes[N]; int par[N], fa[N]; //par[i]为i的直接前驱, fa用于并查集; int ln, cnt; //ln为链的数目,cnt为剖分树中节点的数目 int leaNum; int sons[N], que[N], dep[N], id[N], st[N], ed[N], root[N], top[N], sNum[N]; //sons[i]表示i为根的子树的大小,dep[i]表示节点的i的深度,id[i]为i所在链的标号,st和ed记录每条链的左右标号,root记录每条链的根节点的下标 //top[i]为第i条链的顶部节点,sNum[i]表示i的直接后继的个数 int ith[N], pMin[N], seg[N]; //ith[i]表示节点i是其父节点的第ith[i]个儿子(按访问顺序), //seg在链上构建线段树的时候使用 vType iw[N]; //iw[i]表示节点i在最小生成树中与其他节点之间的边的权值的总和 int tr; //最小生成树的根节点 inline void add_e(int u, int v){ es[en].v = v; es[en].nxt = fir[u]; fir[u] = &es[en++]; } inline void newNode(int& id, int l, int r){ nodes[cnt].ls = nodes[cnt].rs = -1; nodes[cnt].l = l; nodes[cnt].r = r; nodes[cnt].Min = INF; id = cnt++; } void build(int& id, int l, int r){ //在剖分出来的链上构建线段树 newNode(id, l, r); if(l >= r){ nodes[id].id = seg[l]; return ; } int mid = (l+r)>>1; build(nodes[id].ls, l, mid); build(nodes[id].rs, mid+1, r); } void initTree(){ //初始化剖分树 //确定父亲 int l, r, u, v, i; e* cur; l = r = 0; que[r++] = tr; par[tr] = -1; dep[tr] = 0; while(l != r){ u = que[l++]; int g = 1; for(cur = fir[u]; cur; cur = cur->nxt){ if((v = cur->v) != par[u]){ que[r++] = v; par[v] = u; dep[v] = dep[u]+1; ith[v] = g++; } } } //计算子树大小 for(i = 1; i <= n; i++){ sons[i] = 1; sNum[i] = 0; id[i] = -1; } for(i = r-1; i >= 0; i--){ u = que[i]; if(par[u] >= 0){ sons[par[u]] += sons[u]; sNum[par[u]]++; } } //剖分链 l = r = 0; que[r++] = tr; ln = cnt = 0; while(l != r){ u = que[l++]; st[ln] = dep[u]; //用节点的深度作为线段树中区间的左右标号 top[ln] = u; while(u >= 0){ id[u] = ln; ed[ln] = dep[u]; seg[dep[u]] = u; int best; for(cur = fir[u], best=-1; cur; cur = cur->nxt){ if(id[v = cur->v] == -1){ if(best == -1 || (best >= 0 && sons[v] > sons[best])){ best = v; } } } if(best >= 0){ for(cur = fir[u]; cur; cur = cur->nxt){ if(id[v = cur->v] == -1 && best != v){ que[r++] = v; } } } u = best; } root[ln] = -1; build(root[ln], st[ln], ed[ln]); ln++; } } int qrylKthFar(int& id, int i, int k){ //在链上查询i的第k个父节点(第0个为自己) if(nodes[id].l == nodes[id].r) return nodes[id].id; int mid = nodes[id].mid(); if(i - mid - 1 >= k) return qrylKthFar(nodes[id].rs, i, k); else return qrylKthFar(nodes[id].ls, i, k); } int qryKthFar(int i, int k){ //查询i的第k个父节点(第0个为自己) int u = i, ri; while(true){ ri = id[u]; if(dep[u] - st[ri] >= k){ return qrylKthFar(root[ri], dep[u], k); }else{ k -= dep[u] - st[ri] + 1; u = par[top[ri]]; } } } void inslMin(int& id, int ql, int qr, int mv){ if(id == -1) return ; if(ql <= nodes[id].l && nodes[id].r <= qr){ if(nodes[id].Min > mv){ nodes[id].Min = mv; } return; } if(nodes[id].l == nodes[id].r) return; int mid = nodes[id].mid(); if(ql <= mid){ inslMin(nodes[id].ls, ql, qr, mv); } if(qr > mid){ inslMin(nodes[id].rs, ql, qr, mv); } } void insMin(int i, int k, vType mv){ //在节点i和i的第k个父节点之间插入mv int b, u; u = i; while(true){ b = id[u]; if(dep[u]-st[b] >= k){ inslMin(root[b], dep[u]-k, dep[u], mv); return; }else{ inslMin(root[b], st[b], dep[u], mv); k -= dep[u] - st[b] + 1; u = par[top[b]]; } } } bool input(){ scanf("%d%d", &n, &m); int i, k, tn; for(i = tn = 0; i < m; i++){ scanf("%d%d%d%d", &ses[i].e.first, &ses[i].e.second, &ses[i].len, &k); if(k == 1){ //既然这条边还在使用,可以把它的边权设为0 ses[i].len = 0; } if(ses[i].e.first != ses[i].e.second){ tn++; } } m = tn; return true; } inline bool cmp(se a, se b){ return a.len < b.len; } int findFa(int u){ int k = u; while(k != fa[k]) k = fa[k]; while(u != k){ int tf = fa[u]; fa[u] = k; u = tf; } return k; } void merge(int u, int v){ int fu, fv; fu = findFa(u); fv = findFa(v); fa[fu] = fv; } int kruskal(int n, int m, int& leaNum, bool flag){ //flag为true表示需要构图 int i, ans, k, u, v; for(i = 1; i <= n; i++){ fa[i] = i; } if(flag){ for(i = 1; i <= n; i++){ iw[i] = 0; fir[i] = NULL; } en = leaNum = 0; } sort(ses, ses + m, cmp); for(i = ans = 0, k = 1; k < n && i < m; i++){ u = ses[i].e.first; v = ses[i].e.second; if(findFa(u) != findFa(v)){ ans += ses[i].len; k++; merge(u, v); if(flag){ add_e(u, v); add_e(v, u); iw[u] += ses[i].len; iw[v] += ses[i].len; } }else if(flag){ //这条边被剩出来 lea[leaNum++] = ses[i]; } } if (flag) { for (; i < m; i++) { lea[leaNum++] = ses[i]; } } if(k < n) ans = INF; return ans; } void handlelca(int u, int v, int anc, int len){ if(u != anc && v != anc){ int ku, kv; ku = qryKthFar(u, dep[u] - dep[anc] - 1); kv = qryKthFar(v, dep[v] - dep[anc] - 1); se te; te.e.first = ith[ku]; te.e.second = ith[kv]; te.len = len; nes[anc].push_back(te); } if(dep[anc] + 2 <= dep[u]){ insMin(u, dep[u] - dep[anc] - 2, len); } if(dep[anc] + 2 <= dep[v]){ insMin(v, dep[v] - dep[anc] - 2, len); } } //qn为查询lca的次数,qs记录查询lca的两个几点,anc记录每次查询的结果 int getlca(int u, int v){ while(id[u] != id[v]){ if(id[u] < id[v]) swap(u, v); u = par[top[id[u]]]; } if(dep[u] < dep[v]) swap(u, v); return v; } void lca(se* qs, int qn){ int i; for(i = 1; i <= n; i++){ nes[i].clear(); } for(i = 0; i < qn; i++){ int u, v, anc; u = qs[i].e.first; v = qs[i].e.second; anc = getlca(u, v); handlelca(v, u, anc, qs[i].len); } } void getpMin(int& id, int mv){ if(mv > nodes[id].Min){ mv = nodes[id].Min; } if(nodes[id].l == nodes[id].r){ pMin[nodes[id].id] = mv; return; } getpMin(nodes[id].ls, mv); getpMin(nodes[id].rs, mv); } void getpMin(){ int i; for(i = 0; i < ln; i++){ getpMin(root[i], INF); } } void solve(){ tr = 1; //设置根节点 int sum, i, sn, v, num; e* cur; sum = kruskal(n, m, leaNum, true); initTree(); lca(lea, leaNum); getpMin(); for(i = 1; i <= n; i++){ num = 0; sn = sNum[i]; if (par[i] >= 1) { sn++; for (cur = fir[i]; cur; cur = cur->nxt) { if ((v = cur->v) != par[i] && pMin[v] < INF) { ses[num].e.first = sn; ses[num].e.second = ith[v]; ses[num].len = pMin[v]; num++; } } } int size = nes[i].size(), j; for(j = 0; j < size; j++){ ses[num++] = nes[i][j]; } int ans = kruskal(sn, num, leaNum, false); if(ans < INF){ ans += sum - iw[i]; printf("%d\n", ans); }else{ printf("inf\n"); } } } int main() { int t; scanf("%d", &t); while(t--){ input(); solve(); } return 0; }