思路:
先求出树上的一条最长链(直径),再分两种情况处理:一种是 a、b 两条路径各包含最长链的一个端点(相当于在最长链的某条边处把它切开);另一种是 a 占据整条最长链,b 完全落在最长链以外的子树中。
#include <algorithm>
#include <cstdio>
#include <queue>
#include <utility>
#include <vector>

// Approach: extract one longest chain (diameter) of the tree, then record
// achievable length pairs (a, b) of two vertex-disjoint paths in two cases:
//   1. cut the chain at some edge: a = longest path on the prefix side,
//      b = longest path on the suffix side;
//   2. a = the whole chain, b = the longest path lying entirely inside a
//      subtree hanging off the chain.
// Lengths are counted in vertices.

using ll = long long;

constexpr int inf = 0x3f3f3f3f;
constexpr int maxn = 1e5 + 9;

int n;                      // number of vertices in the current test case (1..n)
std::vector<int> mp[maxn];  // adjacency lists
int vis[maxn];              // vis[v] == 1 iff v lies on the extracted chain
int dis[maxn];              // BFS distance scratch array
std::vector<int> lian;      // the extracted chain, listed from one end to the other

// BFS from `src` over the whole tree. Fills dis[] with distances from `src`
// and returns the first vertex (in BFS order) attaining the maximum distance.
int bfs_farthest(int src) {
    for (int i = 1; i <= n; i++) dis[i] = inf;
    dis[src] = 0;
    std::queue<int> que;
    que.push(src);
    int far = src;
    while (!que.empty()) {
        const int u = que.front();
        que.pop();
        if (dis[u] > dis[far]) far = u;
        for (int v : mp[u]) {
            if (dis[v] > dis[u] + 1) {
                dis[v] = dis[u] + 1;
                que.push(v);
            }
        }
    }
    return far;
}

// Extracts one longest chain into `lian` (lian[0] = s, lian.back() = t)
// and marks its vertices in vis[].
void koulian() {
    const int t = bfs_farthest(1);
    int s = bfs_farthest(t);  // dis[] now holds distances from t
    lian.push_back(s);
    vis[s] = 1;
    // Walk from s back to t, always stepping to a neighbour one layer closer to t.
    while (s != t) {
        for (int v : mp[s]) {
            if (dis[v] + 1 == dis[s]) {
                s = v;
                lian.push_back(s);
                vis[s] = 1;
                break;
            }
        }
    }
}

int dpa[maxn];      // vertex count of the longest downward path from u avoiding other chain vertices
int dpb[maxn][2];   // two largest downward path lengths from u through distinct children (dpb[u][0] >= dpb[u][1])
int pre[maxn];      // off-chain u: longest path inside u's subtree; chain lian[i]: max of that over lian[0..i]
int dppre[maxn];    // dppre[i] = max over j <= i of (j + dpa[lian[j]])
int dpback[maxn];   // dpback[i] = max over j >= i of ((all - 1 - j) + dpa[lian[j]])
int hei[maxn];      // hei[a] = largest b recorded for a pair (a, b)

// Computes dpa[] for u and every vertex below it. Neighbours equal to `fa`
// or lying on the chain are not descended into. Called as dfs1(v, v) for a
// chain vertex v. Recursion depth is the height of the hanging subtree.
void dfs1(int u, int fa) {
    dpa[u] = 1;
    for (int v : mp[u]) {
        if (v == fa || vis[v]) continue;
        dfs1(v, u);
        dpa[u] = std::max(dpa[u], dpa[v] + 1);
    }
}

// Computes dpb[] and pre[] for the off-chain vertex u and everything below it.
// Children are neighbours other than `fa` that are not on the chain.
void dfs2(int u, int fa) {
    dpb[u][0] = dpb[u][1] = 1;
    pre[u] = 1;
    for (int v : mp[u]) {
        if (vis[v] || v == fa) continue;
        dfs2(v, u);
        pre[u] = std::max(pre[u], pre[v]);
        // Keep the two largest candidate branch lengths, largest first.
        const int cand = dpb[v][0] + 1;
        if (dpb[u][1] <= cand) {
            dpb[u][1] = cand;
            if (dpb[u][0] < dpb[u][1]) std::swap(dpb[u][0], dpb[u][1]);
        }
    }
    // Best path bending at u joins its two longest branches (u counted once).
    pre[u] = std::max(pre[u], dpb[u][0] + dpb[u][1] - 1);
}

int main() {
    int T;
    if (scanf("%d", &T) != 1) return 0;
    while (T--) {
        if (scanf("%d", &n) != 1) return 0;
        for (int i = 1; i < n; i++) {  // a tree on n vertices has n - 1 edges
            int u, v;
            if (scanf("%d%d", &u, &v) != 2) return 0;
            mp[u].push_back(v);
            mp[v].push_back(u);
        }
        for (int i = 0; i <= n; i++) {
            vis[i] = hei[i] = pre[i] = dppre[i] = dpback[i] = 0;
        }

        koulian();
        const int all = static_cast<int>(lian.size());

        // Prefix information along the chain.
        for (int i = 0; i < all; i++) {
            const int v = lian[i];
            dfs1(v, v);
            dppre[i] = (i == 0) ? dpa[v] : std::max(dppre[i - 1], dpa[v] + i);
            for (int p : mp[v]) {
                if (vis[p]) continue;
                dfs2(p, p);
                pre[v] = std::max(pre[v], pre[p]);
            }
            if (i > 0) pre[v] = std::max(pre[v], pre[lian[i - 1]]);
        }

        // Suffix information along the chain.
        for (int i = all - 1; i >= 0; i--) {
            const int cand = dpa[lian[i]] + (all - 1 - i);
            dpback[i] = (i == all - 1) ? cand : std::max(dpback[i + 1], cand);
        }

        // Case 2: one path is the whole chain, the other stays off the chain.
        const int off = pre[lian[all - 1]];
        hei[all] = std::max(hei[all], off);
        hei[off] = std::max(hei[off], all);

        // Case 1: cut the chain between lian[i - 1] and lian[i].
        for (int i = all - 1; i >= 1; i--) {
            const int a = dppre[i - 1];
            const int b = dpback[i];
            hei[a] = std::max(hei[a], b);
            hei[b] = std::max(hei[b], a);
        }

        // For each a, any length pair dominated by a recorded pair counts.
        // NOTE(review): this is presumably the number of ordered (a, b) pairs
        // the problem asks for -- confirm against the statement.
        ll sum = 0;
        int best = 0;
        for (int i = all; i >= 1; i--) {
            best = std::max(best, hei[i]);
            sum += best;
        }
        printf("%lld\n", sum);

        lian.clear();
        for (int i = 1; i <= n; i++) mp[i].clear();
    }
    return 0;
}
就是字典树 + 贪心:做法与第五场那道题相同,只需把贪心的优先顺序反过来(从高位到低位,优先让两数在该位上不同,即异或结果为 1)即可。
#include <cstdio>

// Approach: build one binary trie per array over the low 30 bits of each
// value, then extract n pairs one at a time. Each extraction walks both tries
// from the top bit down, preferring children whose bits differ (XOR bit 1)
// and falling back to equal bits. The two chosen values are removed and
// their XOR is added to the answer.
// NOTE(review): assumes 0 <= a_i, b_i < 2^30 (higher bits are ignored and
// negative inputs are unsupported) -- confirm against the problem statement.

using ll = long long;

constexpr int maxn = 1e5 + 9;
constexpr int kBits = 30;  // bits stored per value: 29 down to 0

int tot[2];  // tot[f]: index of the last allocated node in trie f
int rt[2];   // rt[f]: root of trie f (trie 0 holds array a, trie 1 holds array b)

// Trie node. Index 0 is the shared "absent child" and always keeps sz == 0.
struct node {
    int ch[2];  // child index for bit 0 / bit 1 (0 = absent)
    int sz;     // number of values stored in this subtree, with multiplicity
    void init() {
        ch[0] = ch[1] = 0;
        sz = 0;
    }
} tree[2][maxn * kBits];

// Inserts `val` into trie `flag`, bumping the count of every node on its path.
void insert(int val, int flag) {
    node* t = tree[flag];
    int p = rt[flag];
    t[p].sz++;
    for (int bit = kBits - 1; bit >= 0; --bit) {
        const int c = (val >> bit) & 1;
        if (t[p].ch[c] == 0) {
            t[p].ch[c] = ++tot[flag];
            t[tot[flag]].init();
        }
        p = t[p].ch[c];
        t[p].sz++;
    }
}

// Removes one value from each trie, chosen greedily bit by bit from the top.
// At each level it tries (a-bit, b-bit) = (0,1), (1,0), (0,0), (1,1) in that
// order. It takes the first pair of non-empty children, so XOR bit 1 wins
// whenever available. Returns the XOR of the two removed values.
// Both tries must be non-empty.
int take_pair() {
    static const int kOrder[4][2] = {{0, 1}, {1, 0}, {0, 0}, {1, 1}};
    int p = rt[0];
    int q = rt[1];
    int val = 0;
    tree[0][p].sz--;
    tree[1][q].sz--;
    for (int bit = kBits - 1; bit >= 0; --bit) {
        for (const auto& o : kOrder) {
            const int np = tree[0][p].ch[o[0]];
            const int nq = tree[1][q].ch[o[1]];
            if (tree[0][np].sz && tree[1][nq].sz) {
                p = np;
                q = nq;
                if (o[0] != o[1]) val |= 1 << bit;
                break;
            }
        }
        // The removed value leaves every subtree along both chosen paths.
        tree[0][p].sz--;
        tree[1][q].sz--;
    }
    return val;
}

int main() {
    int T;
    if (scanf("%d", &T) != 1) return 0;
    while (T--) {
        for (int f = 0; f < 2; ++f) {
            tot[f] = 0;
            rt[f] = ++tot[f];
            tree[f][rt[f]].init();
        }
        int n;
        if (scanf("%d", &n) != 1) return 0;
        for (int f = 0; f < 2; ++f) {  // first array a into trie 0, then b into trie 1
            for (int i = 1; i <= n; i++) {
                int x;
                if (scanf("%d", &x) != 1) return 0;
                insert(x, f);
            }
        }
        ll sum = 0;  // n XORs of up to 2^30 each can exceed int range
        for (int i = 1; i <= n; i++) sum += take_pair();
        printf("%lld\n", sum);
    }
    return 0;
}