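连接后新树的直径只有两种可能:要么是两棵原树直径中的较大者,要么是经过新边的路径,其长度为两个连接点在各自树中的最长链之和再加 1。做法:先预处理每个点出发的最长链(从任一点 BFS 找到直径一端,再从直径两端各做一次 BFS,每个点取两次距离的最大值);把一棵树所有点的最长链排序并求后缀和;再遍历另一棵树的每个节点,二分找出"新路径不短于原直径"的分界位置,用后缀和算出该节点与另一棵树所有点配对的直径之和,累加即为答案。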
// For two trees A (n nodes) and B (m nodes), sums the diameter of the tree
// obtained by joining node u of A to node v of B with one new edge, over all
// n * m choices of (u, v).
//
// Joining u and v gives diameter max(D, ecc_A(u) + ecc_B(v) + 1), where
// D = max(diam(A), diam(B)) and ecc_X(x) is the length of the longest path in
// X that starts at x. Eccentricities come from BFS runs started at both ends
// of a diameter. B's eccentricities are sorted and suffix-summed, so each node
// of A is handled with one binary search: O((n + m) log m) per test case.
//
// Input (repeated until EOF): "n m", then n-1 edges "u v" of A, then m-1 edges
// of B, nodes numbered from 1. Output: one line with the sum per test case.
//
// All storage is sized per test case, so many small cases do not pay for
// clearing fixed maximum-size arrays, and larger trees need no recompilation.
#include <algorithm>
#include <cstdio>
#include <queue>
#include <vector>

typedef long long LL;
typedef std::vector<std::vector<int> > Graph;  // adjacency lists, index 0 unused

namespace {

// Reads `nodes - 1` undirected edges "u v" into g (nodes numbered 1..nodes).
// Returns false on EOF/malformed input or an out-of-range endpoint, so a bad
// case stops processing instead of writing outside the adjacency lists.
bool readTree(int nodes, Graph& g) {
    if (nodes < 0) return false;
    g.assign(nodes + 1, std::vector<int>());
    for (int i = 1; i < nodes; ++i) {
        int u, v;
        if (std::scanf("%d%d", &u, &v) != 2) return false;
        if (u < 1 || u > nodes || v < 1 || v > nodes) return false;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    return true;
}

// BFS over g from `src`. On return dist[x] is the edge count from src to x
// (-1 if unreachable). Returns the first node dequeued at maximum distance.
int bfsFarthest(const Graph& g, int src, std::vector<int>& dist) {
    dist.assign(g.size(), -1);
    std::queue<int> q;
    q.push(src);
    dist[src] = 0;
    int farthest = src;
    while (!q.empty()) {
        const int u = q.front();
        q.pop();
        if (dist[u] > dist[farthest]) farthest = u;
        for (size_t k = 0; k < g[u].size(); ++k) {
            const int v = g[u][k];
            if (dist[v] == -1) {
                dist[v] = dist[u] + 1;
                q.push(v);
            }
        }
    }
    return farthest;
}

// Returns the eccentricity of every node of tree g (node x at index x - 1)
// and stores the tree's diameter in `diameter`. An empty tree yields an empty
// vector and diameter 0.
std::vector<int> eccentricities(const Graph& g, int& diameter) {
    const int nodes = static_cast<int>(g.size()) - 1;
    std::vector<int> ecc;
    diameter = 0;
    if (nodes < 1) return ecc;

    std::vector<int> fromEndA, fromEndB;
    // The node farthest from any start is one end of a diameter; the node
    // farthest from that end is the other.
    const int endA = bfsFarthest(g, 1, fromEndA);
    const int endB = bfsFarthest(g, endA, fromEndA);
    diameter = fromEndA[endB];
    bfsFarthest(g, endB, fromEndB);

    // In a tree, every node's farthest node is one of the diameter ends.
    ecc.resize(nodes);
    for (int x = 1; x <= nodes; ++x) ecc[x - 1] = std::max(fromEndA[x], fromEndB[x]);
    return ecc;
}

// Returns the sum of max(D, a + b + 1) over every a in eccA and b in eccB.
LL sumOfDiameters(const std::vector<int>& eccA, std::vector<int> eccB, LL D) {
    std::sort(eccB.begin(), eccB.end());
    const int m = static_cast<int>(eccB.size());

    // suffix[j] = eccB[j] + ... + eccB[m-1]; suffix[m] = 0.
    std::vector<LL> suffix(m + 1, 0);
    for (int j = m - 1; j >= 0; --j) suffix[j] = suffix[j + 1] + eccB[j];

    LL total = 0;
    for (size_t i = 0; i < eccA.size(); ++i) {
        const LL a = eccA[i];
        // eccB[t..m-1] satisfy a + b + 1 >= D, so for them the path through
        // the new edge is the diameter; for eccB[0..t-1] the diameter is D.
        const int t = static_cast<int>(
            std::lower_bound(eccB.begin(), eccB.end(), D - a - 1) - eccB.begin());
        total += suffix[t] + static_cast<LL>(m - t) * (a + 1);
        total += D * t;
    }
    return total;
}

}  // namespace

int main() {
    int n, m;
    Graph treeA, treeB;
    // `== 2` (not `!= EOF`) so malformed input ends the loop instead of spinning.
    while (std::scanf("%d%d", &n, &m) == 2) {
        // A's edges precede B's in the input, so read them in that order.
        if (!readTree(n, treeA) || !readTree(m, treeB)) break;

        int diamA = 0, diamB = 0;
        const std::vector<int> eccA = eccentricities(treeA, diamA);
        const std::vector<int> eccB = eccentricities(treeB, diamB);
        const LL D = std::max(diamA, diamB);

        std::printf("%lld\n", sumOfDiameters(eccA, eccB, D));
    }
    return 0;
}