题意翻译
给定一棵 n 个点的带边权树，树上有 m 个黑点。求一条路径，使得这条路径经过的黑点数小于等于 k，且路径长度（路径上边权之和）最大。
Code:
#include <algorithm>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>
using namespace std;
// (key, vertex) pair type used for the per-child list `vrr`.
// A type alias instead of a macro, so `vector<pr>` and `pr(a, b)` both work.
using pr = std::pair<int, int>;
// Shorthand for std::make_pair (kept as a macro for existing call sites).
#define mp std::make_pair
// Capacity of every per-vertex array; edge arrays use twice this size.
constexpr int maxn = 2000003;
// Sentinel used to mark "no value yet" entries (stored as -inf).
constexpr int inf = 1000000000;
// Local-debugging helper: redirect stdin to read from the file "<a>.in".
// Only stdin is redirected; stdout is left untouched.
// @param a  base file name without extension.
void setIO(const std::string& a)
{
    const std::string in = a + ".in";
    // Report a failed open instead of silently continuing with a bad stdin.
    if (std::freopen(in.c_str(), "r", stdin) == nullptr)
        std::perror(in.c_str());
}
// For each child of the current centroid: (max black count on a path from that
// child down into its subtree, child vertex). Filled and sorted by calc().
vector<pair<int, int>> vrr;
// edges: number of stored directed edges; n, K, m: values read in main();
// ans: best path length found so far; mdep: scratch maximum written by getmax();
// tl: number of valid entries in tmp1/tmp2; root / sn: centroid found by
// Getroot() and the size of the component being searched.
int edges, n, m, K, ans, mdep, tl, root = 0, sn;
// Forward-star adjacency list: hd[u] = first edge id of u, nex[e] = next edge of
// the same tail, to[e] = head vertex, val[e] = edge weight. siz[u] = subtree size.
int hd[maxn], to[maxn << 1], nex[maxn << 1], val[maxn << 1], siz[maxn];
// mk: 1 for black vertices; vis: vertex already used as a centroid;
// mine[c]: best distance to the centroid using at most c black vertices;
// dis: distance to the current centroid; tmp1/tmp2: collected (distance, black
// count) pairs; f[u]: largest component left if u were removed.
int mk[maxn], vis[maxn], mine[maxn], dis[maxn], tmp1[maxn], tmp2[maxn], f[maxn];
// Append the directed edge u -> v with weight c to the forward-star list.
// The new edge becomes the head of u's edge chain.
void add(int u, int v, int c)
{
    ++edges;
    to[edges] = v;
    val[edges] = c;
    nex[edges] = hd[u];
    hd[u] = edges;
}
// Centroid search over the component containing u (vertices with vis set act
// as walls). Fills siz[] with subtree sizes relative to the DFS parent ff and
// f[] with the largest piece left if that vertex were removed, and keeps the
// vertex with the smallest f[] seen so far in root.
// Expects the caller to have set sn (component size) and root = 0.
// NOTE(review): also assumes f[0] was pre-set to a large value (not visible
// here); otherwise no vertex can beat the root == 0 sentinel.
void Getroot(int u, int ff)
{
    siz[u] = 1;
    f[u] = 0;
    for (int e = hd[u]; e != 0; e = nex[e])
    {
        const int w = to[e];
        if (w != ff && !vis[w])
        {
            Getroot(w, u);
            siz[u] += siz[w];
            if (siz[w] > f[u]) f[u] = siz[w];
        }
    }
    const int above = sn - siz[u];   // size of the piece on the parent side
    if (above > f[u]) f[u] = above;
    if (f[u] < f[root]) root = u;
}
// DFS below the centroid starting at x (parent ff).
// d1 = number of black vertices from the subtree root down to x's parent.
// Updates:
//  - mdep with the largest black count seen on any downward path;
//  - dis[] for children (dis[x] must already be set by the caller);
//  - siz[] with subtree sizes.
void getmax(int x, int ff, int d1)
{
    const int blacks = d1 + mk[x];   // black count down to x, inclusive
    if (blacks > mdep) mdep = blacks;
    siz[x] = 1;
    for (int e = hd[x]; e != 0; e = nex[e])
    {
        const int y = to[e];
        if (y != ff && !vis[y])
        {
            dis[y] = dis[x] + val[e];
            getmax(y, x, blacks);
            siz[x] += siz[y];
        }
    }
}
// Collect (dis[x], black count) pairs for every vertex reachable from x
// without crossing ff or a used centroid, appending them at tmp1/tmp2[++tl].
// d2 = black vertices above x on the way down. Recursion stops once d2
// exceeds K; a recorded count may still be K + 1, and callers filter that out.
void getdis(int x, int ff, int d2)
{
    if (d2 <= K)
    {
        const int blacks = d2 + mk[x];
        ++tl;
        tmp1[tl] = dis[x];
        tmp2[tl] = blacks;
        for (int e = hd[x]; e != 0; e = nex[e])
        {
            const int y = to[e];
            if (y != ff && !vis[y]) getdis(y, x, blacks);
        }
    }
}
void calc(int x)
{
if(mk[x]) --K;
vrr.clear();
siz[x]=1;
for(int i = hd[x]; i ; i = nex[i])
{
int v = to[i];
if(vis[v]) continue;
mdep = 0, dis[v] = val[i];
getmax(v, x, 0);
siz[x]+=siz[v];
vrr.push_back(mp(mdep, v));
}
sort(vrr.begin(), vrr.end());
tl = mine[0] = 0;
// printf("%d ::: ",x);
for(int sz = vrr.size(), i = 0; i < sz; ++i)
{
int cur = vrr[i].second;
int pdl = tl;
getdis(cur, x, 0);
for(int j = pdl + 1; j <= tl; ++j)
if(K >= tmp2[j])
ans = max(ans, mine[K - tmp2[j]] + tmp1[j]);
// printf("%d ",vrr[i].second);
// tmp2[j] :: 有tmp2[j]个黑子时的最大值.
for(int j = pdl + 1; j <= tl; ++j) mine[tmp2[j]] = max(mine[tmp2[j]], tmp1[j]);
for(int j = 1; j <= vrr[i].first; ++j) mine[j] = max(mine[j-1],mine[j]);
}
// printf("\n");
if(vrr.size())
for(int i=1;i<=vrr[vrr.size()-1].first;++i) mine[i] = -inf;
if(mk[x]) ++K;
}
// Point divide-and-conquer driver. Marks u as used, combines paths through u
// via calc(), then recurses into the centroid of each remaining component
// hanging off u. The component sizes come from siz[], which calc() refreshed.
void solve(int u)
{
    vis[u] = 1;
    calc(u);
    for (int e = hd[u]; e != 0; e = nex[e])
    {
        const int w = to[e];
        if (!vis[w])
        {
            root = 0;
            sn = siz[w];
            Getroot(w, u);
            solve(root);
        }
    }
}
int main()
{
// setIO("input");
scanf("%d%d%d",&n, &K, &m);
for(int i = 1,o; i <= m ; ++i)
{
scanf("%d",&o);
mk[o] = 1;
}
for(int i = 1; i < n ; ++i)
{
int u, v, c;
scanf("%d%d%d",&u, &v, &c);
add(u, v, c), add(v, u, c);
}
for(int i=1;i