JZOJ6436. 【GDOI2020模拟01.16】zsy家今天的饭

许久未打比赛,果然一打就炸,体验极差。只改出了一道所谓的签到题。

题目大意

有一棵n个点的树,其中有m个关键点。在这m个关键点中随机选择k个点,从任意一点开始到任意一点结束的经过所有的k个点的最短路程期望是多少。 1 ≤ n ≤ 1 e 5 , k ≤ m ≤ 500 1 ≤ n ≤ 1e5, k ≤ m ≤ 500 1n1e5,km500

分析

这道题需要用到虚数的思路,但并不需要真的将虚树建出来。运用到虚树的思路就是先将所有关键点提出来建一棵虚树,那么答案就是 2 S u m − L 2Sum -L 2SumL,即虚树的边权和的两倍减去直径,此处钦定若有多条直径,取编号最小的直径。那么我们不妨拆成两个部分分辨算期望:

  • 第一部分:对于任意一条边出现的概率一定是
    ( ( m k ) − ( x k ) − ( y k ) ) / ( m k ) (\tbinom{m}{k} - \tbinom{x}{k} - \tbinom{y}{k})/\tbinom{m}{k} ((km)(kx)(ky))/(km)其中 x , y x,y x,y分别是这条边以上和这条边一下的餐厅的个数,那么这个概率很显然,因为只有该边的两边有餐厅出现该边就一定会出现。
  • 第二部分:对于任意两个点 u , v u,v u,v,考虑有多少个点能出现,枚举每个点 x x x判断 d i s ( u , x ) dis(u, x) dis(u,x) d i s ( v , x ) dis(v, x) dis(v,x) d i s ( u , v ) dis(u, v) dis(u,v)的关系以及相等时判断标号大小即可。
    那么这道题就不是很难了,只用先预处理出任意两点间的距离,以及任意边两侧的餐厅的个数。

Code

#include 
#include 
#define ll long long
#define C(a,b) (G[a][b])
using namespace std;

const int N = 1e5 + 10;
const int M = 1e3 + 10;
const ll mo = 998244353;
struct Edge{
	int to,next;
	ll val;
} f[N << 1];
int n,m,k,a[N],cnt,head[N],fa[N][25],dep[N],num[N];
bool tag[N];
ll ans,dis[N],G[M][M],dist[M][M];
double an;

int read()
{
	int x = 0;
	char ch = getchar();
	while (ch < '0' || ch > '9') ch = getchar();
	while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0'; ch = getchar();}
	return x;
}

void add(int u,int v,int w)
{
	f[++ cnt].to = v;
	f[cnt].next = head[u];
	f[cnt].val = (ll)w;
	head[u] = cnt;
}

ll ksm(ll a,int b)
{
	ll res = 1ll;
	while (b)
	{
		if (b & 1)  res = res * a % mo;
		b >>= 1;
		a = a * a % mo;
	}
	return res;
}

void dfs(int u)
{
	if (tag[u] == 1) num[u] = 1;
	for (int i = head[u]; i; i = f[i].next)
	{
		if (f[i].to == fa[u][0]) continue;
		fa[f[i].to][0] = u;
		dep[f[i].to] = dep[u] + 1;
		dis[f[i].to] = dis[u] + f[i].val;
		dfs(f[i].to);
		num[u] += num[f[i].to];
		ll tmp = (C(num[f[i].to],k) + C(m - num[f[i].to],k)) % mo;
		tmp = (C(m,k) - tmp) % mo;
		ll p = ans;
		ans = (ans + (ll)(f[i].val % mo * tmp % mo)) % mo;
	}
}

int lca(int u,int v)
{
	if (dep[u] < dep[v]) swap(u,v);
	for (int i = 20; i >= 0; i --) 
		if (dep[fa[u][i]] >= dep[v]) u = fa[u][i];
	if (u == v) return u;
	for (int i = 20; i >= 0; i --)
		if (fa[u][i] != fa[v][i]) u = fa[u][i],v = fa[v][i];
	return fa[v][0];
}

ll get(int u,int v) {return dis[u] + dis[v] - 2 * dis[lca(u,v)];}

int main()
{
	freopen("meal.in","r",stdin);
	freopen("meal.out","w",stdout);
	n = read(),m = read(),k = read();
	for (int i = 1; i <= m; i ++) tag[a[i] = read()] = 1;
	for (int i = 1,u,v,w; i < n; i ++) u = read(),v = read(),w = read(),add(u,v,w),add(v,u,w);
	dep[1] = 1;
	for (int i = 0; i <= m; i ++)
	{
		G[i][0] = 1ll;
		for (int j = 1; j <= i; j ++) G[i][j] = (G[i - 1][j] + G[i - 1][j - 1]) % mo;
	}
	ans = 0;
	dfs(1);
	for (int i = 1; i <= 20; i ++) 
		for (int u = 1; u <= n; u ++) fa[u][i] = fa[fa[u][i - 1]][i - 1];
	ans = ans * 2ll % mo;
	for (int i = 1; i <= m; i ++)
		for (int j = 1; j <= m; j ++) dist[i][j] = get(a[i],a[j]);
	for (int i = 1; i < m; i ++)
		for (int j = i + 1,g; j <= m; j ++)
		{     
			g = 0;
			ll l1 = dist[i][j];
			for (int l = 1; l <= m; l ++)
			{
				ll l2 = dist[i][l],l3 = dist[j][l];
				if ((l1 > l2 || (l1 == l2 && l > j)) && 
					(l1 > l3 || (l1 == l3 && l > i)))
				g ++;
			}
			ans = (ans - l1 % mo * C(g,k - 2) % mo) % mo;
		}
	ans = ans * ksm(C(m,k),mo - 2) % mo;
	printf("%lld\n",(ans + mo) % mo);
	return 0;
}

一道签到题都打死我了。。

你可能感兴趣的:(题解)