题意:给定一棵树,求出树上的一点,使得树上的所有点到该点的距离之和最小。
思路:暴力显然是O(N^2)等死对吧。
我们首先将无根树转化为有根树,然后一边dfs求出f[i],size[i].
f[i]表示以i为根的子树中所有的点到i的距离之和,size[i]表示以i为根的子树的点数。
下面开始脑洞大开:
现在对于我们一开始的那个root,我们已经知道了答案。问题就是如何快速的推知别的点作为根时的答案。
我们重新进行一次dfs,当找到x时,我们用dp[fa[x]]+padis[x]*size[fa[x]]更新答案。
我们记录一下当前的dp[x],以及size[x].
每找到一个儿子son,向下dfs时,我们令dp[x]=dp[fa[x]]+size[fa[x]]*padis[x]+dp[x]-dp[son]-size[son]*padis[son],size[x]=size[fa[x]]+size[x]-size[son],然后再向下dfs.
不要问我为什么。。。
我的代码用的是更加脑洞大开的方法。。。
Code:
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> using namespace std; #define N 100010 int head[N], next[N << 1], end[N << 1], len[N << 1]; void addedge(int a, int b, int _len) { static int q = 1; len[q] = _len; end[q] = b; next[q] = head[a]; head[a] = q++; } int num[N]; long long dp[N]; int pa[N], padis[N], size[N]; void dfs(int x, int fa) { size[x] = num[x]; for(int j = head[x]; j; j = next[j]) if (end[j] != fa) pa[end[j]] = x, padis[end[j]] = len[j], dfs(end[j], x); for(int j = head[x]; j; j = next[j]) if (end[j] != fa) dp[x] += dp[end[j]] + (long long)size[end[j]] * len[j], size[x] += size[end[j]]; } long long res = 1LL << 60; int presize[N], sufsize[N], addsize[N], sav[N], top; long long pre[N], suf[N], add[N]; void work(int x) { long long ans = add[x] + (long long)addsize[x] * padis[x] + dp[x]; if (ans < res) res = ans; register int i, j; top = 0; for(j = head[x]; j; j = next[j]) if (end[j] != pa[x]) sav[++top] = end[j]; presize[0] = pre[0] = 0, sufsize[top + 1] = suf[top + 1] = 0; for(i = 1; i <= top; ++i) presize[i] = presize[i - 1] + size[sav[i]], pre[i] = pre[i - 1] + dp[sav[i]] + (long long)size[sav[i]] * padis[sav[i]]; for(i = top; i >= 1; --i) sufsize[i] = sufsize[i + 1] + size[sav[i]], suf[i] = suf[i + 1] + dp[sav[i]] + (long long)size[sav[i]] * padis[sav[i]]; for(i = 1; i <= top; ++i) { addsize[sav[i]] = addsize[x] + num[x] + presize[i - 1] + sufsize[i + 1]; add[sav[i]] = add[x] + (long long)addsize[x] * padis[x] + pre[i - 1] + suf[i + 1]; } for(j = head[x]; j; j = next[j]) if (end[j] != pa[x]) work(end[j]); } int main() { int n; scanf("%d", &n); register int i, j; for(i = 1; i <= n; ++i) scanf("%d", &num[i]); int a, b, x; for(i = 1; i < n; ++i) { scanf("%d%d%d", &a, &b, &x); addedge(a, b, x); addedge(b, a, x); } dfs(1, -1); work(1); printf("%lld", res); return 0; }