“蔚来杯”2022牛客暑假多校训练营部分题解 1

1J. Serval and Essay

题目大意

  • n n n 个点 m m m 条边的无重边无自环的有向图。
  • 初始时可以选择一个点染黑,其余点均为白点。
  • 若某点所有入边起点均为黑点,则该点可被染黑。
  • 最大化图中黑点数目。
  • 多组数据, 1 ≤ ∑ n ≤ 2 × 1 0 5 , 1 ≤ ∑ m ≤ 5 × 1 0 5 1 \le \sum n \le 2 \times 10^5, 1 \le \sum m\le 5\times 10^5 1n2×105,1m5×105

算法 1

  • S u S_u Su 表示 u u u 为初始黑点迭代得到的最终黑点集合, I u I_u Iu u u u 所有入边起点构成的集合。
  • 显然 ∀ u , v \forall u,v u,v S u , S v S_u, S_v Su,Sv 只能为包含或不交的关系。
  • 初始时 S u = { u } S_u = \{u\} Su={u},迭代的过程即,若 ∃ v , I v ⊆ S u \exist v, I_v \subseteq S_u v,IvSu,则 S u ← S u ∪ S v S_u \leftarrow S_u \cup S_v SuSuSv
  • 不断重复上述过程,直到找不到符合条件的 v v v,答案为 max ⁡ { ∣ S u ∣ } \max\{|S_u|\} max{Su}
  • 用并查集维护 S u S_u Su
  • 初始时 I v ⊆ S u I_v \subseteq S_u IvSu 就等价于 S v S_v Sv 仅有一条来自 S u S_u Su 的入边。
  • 可以用 set \text{set} set 维护每个 S u S_u Su 所有出边指向的点的集合(除去指向 S u S_u Su 内部的点),合并 S u , S v S_u, S_v Su,Sv 时合并出边就能使上述等价条件在任意情况下成立,只要维护每个点的入度就能很容易找到点 v v v
  • 合并时采用启发式合并,时间复杂度 O ( ( n + m ) log ⁡ 2 n ) \mathcal O((n + m)\log^2n) O((n+m)log2n)
#include 

template <class T>
inline void read(T &res)
{
	char ch; bool flag = false; res = 0;
	while (ch = getchar(), !isdigit(ch) && ch != '-');
	ch == '-' ? flag = true : res = ch ^ 48;
	while (ch = getchar(), isdigit(ch))
		res = res * 10 + ch - 48;
	flag ? res = -res : 0;
}

template <class T>
inline void put(T x)
{
	if (x > 9)
		put(x / 10);
	putchar(x % 10 + 48);
}

template <class T>
inline void _put(T x)
{
	if (x < 0)
		x = -x, putchar('-');
	put(x);
}

template <class T>
inline void CkMin(T &x, T y) {x > y ? x = y : 0;}
template <class T>
inline void CkMax(T &x, T y) {x < y ? x = y : 0;}
template <class T>
inline T Min(T x, T y) {return x < y ? x : y;}
template <class T>
inline T Max(T x, T y) {return x > y ? x : y;}
template <class T>
inline T Abs(T x) {return x < 0 ? -x : x;}

using std::set;
using std::map;
using std::pair;
using std::string;
using std::vector;
using std::multiset;
using std::priority_queue;

typedef long long ll;
typedef long double ld;
typedef set<int>::iterator it;
const int Maxn = 1e9;
const int N = 2e5 + 5;
int n, top, T_data;
set<int> out[N];
int ind[N], pre[N], stk[N], sze[N], fa[N];

inline int ufs_find(int x)
{
	if (fa[x] != x)
		return fa[x] = ufs_find(fa[x]);
	return x;
}

inline void Merge(int x, int y)
{
	int tx = ufs_find(x), 
		ty = ufs_find(y);
	if (tx == ty)
		return ;
	if (out[tx].size() < out[ty].size())
		std::swap(tx, ty);
	fa[ty] = tx;
	sze[tx] += sze[ty];

	for (it e2 = out[ty].begin(); e2 != out[ty].end(); ++e2)
	{
		y = *e2;
		it e1 = out[tx].find(y);
		if (e1 == out[tx].end())
			out[tx].insert(y);
		else
		{
			--ind[y];
			if (ind[y] == 1)
				stk[++top] = y;
		}
	}
}

int main()
{
	read(T_data);
	for (int t = 1; t <= T_data; ++t)
	{
		read(n);
		top = 0;
		for (int i = 1; i <= n; ++i)
		{
			out[i].clear();
			fa[i] = i;
			sze[i] = 1;
		}
		for (int i = 1, x; i <= n; ++i)
		{
			read(ind[i]);
			for (int j = 1; j <= ind[i]; ++j)
			{
				read(x);
				out[x].insert(i);
			}
			pre[i] = x;
			if (ind[i] == 1)
				stk[++top] = i;
		}
		
		while (top)
		{
			int x = stk[top--];
			Merge(pre[x], x);
		}
		int ans = 0;
		for (int i = 1; i <= n; ++i)
			if (ufs_find(i) == i)
				CkMax(ans, sze[i]);
		printf("Case #%d: %d\n", t, ans);
	}
	return 0;
}

算法 2

  • 注意到若 v ∈ S u v \in S_u vSu,则 ∣ S v ∣ < ∣ S u ∣ |S_v| < |S_u| Sv<Su
  • 随机一个排列 P P P,若 p i ∉ S p 1 ∪ S p 2 ∪ ⋯ ∪ S p i − 1 p_i \notin S_{p_1} \cup S_{p_2} \cup \dots \cup S_{p_{i -1}} pi/Sp1Sp2Spi1,直接暴力求 S p i S_{p_i} Spi
  • S u S_u Su 的包含关系可以构成树形结构,需要暴力求 S p i S_{p_i} Spi 当且仅当它的所有祖先的 S u S_u Su 均还未求出。
  • 容易分析出这一做法的期望时间复杂度为 O ( ( n + m ) log ⁡ n ) \mathcal O((n+m)\log n) O((n+m)logn)
#include 

template <class T>
inline void read(T &res)
{
	char ch; bool flag = false; res = 0;
	while (ch = getchar(), !isdigit(ch) && ch != '-');
	ch == '-' ? flag = true : res = ch ^ 48;
	while (ch = getchar(), isdigit(ch))
		res = res * 10 + ch - 48;
	flag ? res = -res : 0;
}

template <class T>
inline void put(T x)
{
	if (x > 9)
		put(x / 10);
	putchar(x % 10 + 48);
}

template <class T>
inline void _put(T x)
{
	if (x < 0)
		x = -x, putchar('-');
	put(x);
}

template <class T>
inline void CkMin(T &x, T y) {x > y ? x = y : 0;}
template <class T>
inline void CkMax(T &x, T y) {x < y ? x = y : 0;}
template <class T>
inline T Min(T x, T y) {return x < y ? x : y;}
template <class T>
inline T Max(T x, T y) {return x > y ? x : y;}
template <class T>
inline T Abs(T x) {return x < 0 ? -x : x;}

using std::set;
using std::map;
using std::pair;
using std::string;
using std::vector;
using std::multiset;
using std::priority_queue;

typedef long long ll;
typedef long double ld;
const int mod = 998244353;
const int Maxn = 1e9;
const int N = 2e5 + 5;
int T_data, n, qr;
int que[N], p[N], ind[N], _ind[N];
vector<int> e[N];
bool inv[N], vis[N];

inline void add(int &x, int y)
{
	x += y;
	x >= mod ? x -= mod : 0;
}

inline void dec(int &x, int y)
{
	x -= y;
	x < 0 ? x += mod : 0;
}

int main()
{
	srand(time(0));
	read(T_data);
	for (int t = 1; t <= T_data; ++t)
	{
		read(n);
		for (int i = 1; i <= n; ++i)
		{
			vis[i] = inv[i] = false;
			e[i].clear();
		}
		for (int i = 1; i <= n; ++i)
		{
			read(ind[i]);
			_ind[i] = ind[i];
			for (int j = 1, x; j <= ind[i]; ++j)
			{
				read(x);
				e[x].push_back(i);
			}
		}
		for (int i = 1; i <= n; ++i)
			p[i] = i;
		std::random_shuffle(p + 1, p + n + 1);
		int ans = 0;
		for (int i = 1; i <= n; ++i)
			if (!inv[p[i]])
			{
				que[qr = 1] = p[i];
				inv[p[i]] = vis[p[i]] = true;
				for (int j = 1, x; j <= qr; ++j)
				{
					x = que[j];
					for (int y : e[x])
					{
						if (vis[y])
							continue ;
						if (!--ind[y])
						{
							que[++qr] = y;
							inv[y] = vis[y] = true;
						}
					}
				}
				for (int j = 1, x, y; j <= qr; ++j)
				{
					vis[x = que[j]] = false;
					for (int y : e[x])
						ind[y] = _ind[y];
				}
				CkMax(ans, qr);
			}
		printf("Case #%d: %d\n", t, ans);
	}
	return 0;
}

2H. Take the Elevator

题目大意

  • n n n 个人坐电梯,楼共 k k k 层,每人有起始楼层和目标楼层。

  • 电梯有载客量限制 m m m,上升时可以上升到任意层并随时下降,但是下降要一直下降到一层才能才能再上升。

  • 电梯每秒运行一层,忽略换方向和上下人的时间,问电梯最短运行时间。

  • n , m ≤ 2 × 1 0 5 , k ≤ 1 0 9 n,m\le 2\times 10^5, k\le 10^9 n,m2×105,k109

题解

  • 将每个人视作一条线段 [ l , r ] [l, r] [l,r],不难发现,每趟乘电梯向上和向下的人需要分开处理,但其问题本质相同。

  • 考虑直接贪心,显然每趟从一层出发并回到一层都要取当前 r r r 最大的人,即在需要乘电梯向上和向下的人取最大的 r r r 计入答案,运行过程中我们需要尽可能往电梯中塞人。

  • 实际上塞人的过程也是在贪心取 r r r 最大的人,只是需要考虑载客量的限制,每次取的 r r r 一定单调不增。

  • 对于当前高度(当前在乘电梯的人中最小的 r r r,记作 r min ⁡ r_{\min} rmin),我们用两个 set 分别维护电梯中的人和还未乘电梯的人。

    • 若电梯中的人数 < m <m,我们需要在还未乘电梯的人中找到 r r r 最大且 r ≤ r min ⁡ r \le r_{\min} rrmin 的人。

    • 若电梯中的人数 = m = m =m,我们需要在还未乘电梯的人中找到 r r r 最大且 r ≤ l max ⁡ r \le l_{\max} rlmax 的人,即强制使至少一人走出电梯。

    • 每次更新 r min ⁡ r_{\min} rmin 将需要走出电梯的人在对应的 set 中删除。

  • 这样我们每次操作都能使一人乘上电梯,总时间复杂度 O ( n log ⁡ n ) \mathcal O(n\log n) O(nlogn)

#include 

template <class T>
inline void read(T &res)
{
	char ch; bool flag = false; res = 0;
	while (ch = getchar(), !isdigit(ch) && ch != '-');
	ch == '-' ? flag = true : res = ch ^ 48;
	while (ch = getchar(), isdigit(ch))
		res = res * 10 + ch - 48;
	flag ? res = -res : 0;
}

template <class T>
inline void put(T x)
{
	if (x > 9)
		put(x / 10);
	putchar(x % 10 + 48);
}

template <class T>
inline void _put(T x)
{
	if (x < 0)
		x = -x, putchar('-');
	put(x);
}

template <class T>
inline void CkMin(T &x, T y) {x > y ? x = y : 0;}
template <class T>
inline void CkMax(T &x, T y) {x < y ? x = y : 0;}
template <class T>
inline T Min(T x, T y) {return x < y ? x : y;}
template <class T>
inline T Max(T x, T y) {return x > y ? x : y;}
template <class T>
inline T Abs(T x) {return x < 0 ? -x : x;}

using std::set;
using std::map;
using std::pair;
using std::string;
using std::vector;
using std::multiset;
using std::priority_queue;

typedef long long ll;
typedef long double ld;
const ld pi = acos(-1.0);
const ld eps = 1e-8;
const int N = 4e5 + 5;
const int Maxn = 1e9;
int n, m, k; ll ans;
int a[N], b[N];

struct point
{
	int v, i;
	
	point() {}
	point(int V, int I):
		v(V), i(I) {}
		
	inline bool operator < (const point &a) const 
	{
		return v > a.v || v == a.v && i < a.i;
	}
};
set<point> peo[2], elv[2];
typedef set<point>::iterator it;

inline int solve(int t)
{	
	if (peo[t].begin() == peo[t].end())
		return 0;
	int cnt = 0, h = peo[t].begin()->v, 
		now = k;
	while (1)
	{
		it e = peo[t].lower_bound(point(cnt < m ? now : elv[t].begin()->v, 0));
		if (e == peo[t].end())
			return h;
		elv[t].insert(point(b[e->i], e->i));
		now = a[e->i];
		peo[t].erase(e);
		++cnt, --n;
		while (elv[t].begin()->v >= now)
			elv[t].erase(elv[t].begin()), --cnt;
	}
	return h;
}

int main()
{
	read(n); read(m); read(k);
	for (int i = 1; i <= n; ++i)
	{
		read(a[i]); 
		read(b[i]);
		if (a[i] > b[i])
			peo[0].insert(point(a[i], i));
		else
		{
			std::swap(a[i], b[i]);
			peo[1].insert(point(a[i], i));
		}
	}
	while (n)
	{
		int h1 = solve(0),
			h2 = solve(1);
		ans += 2ll * (Max(h1, h2) - 1);
	}
	std::cout << ans << std::endl;
	
	return 0;
}

3D. Directed

题目大意

  • 给定一棵 n n n 个点的树和一个起点,1 号结点为终点。
  • 随机选择 K K K 条边变成指向终点的单向边,在树上随机游走,求到达终点的期望步数。
  • 1 ≤ n ≤ 1 0 6 , 0 ≤ K ≤ n − 1 1 \le n \le 10^6,0\le K\le n - 1 1n106,0Kn1

题解

  • 先考虑 K = 0 K = 0 K=0 的情况。
  • 结论 给定一棵树,从树中的某点 x x x 出发,设其子树大小为 s i z e x size_x sizex,每次等概率地选择一条相邻的边游走,则第一次游走到其父结点的期望步数为 2 s i z e x − 1 2size_x - 1 2sizex1

证明 设点 x x x 的度为 d e g x deg_x degx,第一次游走到其父结点的期望步数为 f x f_x fx

  • x x x 为叶子结点,显然 f x = 1 = 2 s i z e x − 1 f_x = 1 = 2size_x - 1 fx=1=2sizex1
  • x x x 不为叶子结点,设其子结点为 y 1 , y 2 , … , y k y_1, y_2, \dots, y_k y1,y2,,yk ∀ 1 ≤ i ≤ k , f y i = 2 s i z e y i − 1 \forall 1 \le i \le k, f_{y_i} = 2size_{y_i} - 1 ∀1ik,fyi=2sizeyi1,则

f x = 1 + ∑ i = 1 k ( f y i + f x ) d e g x f x = d e g x + ∑ i = 1 k f y i = 2 ( ∑ i = 1 k s i z e y i + 1 ) − 1 = 2 s i z e x − 1 \begin{aligned}f_x &= 1 + \frac{\sum \limits_{i = 1}^{k}(f_{y_i} + f_x)}{deg_x} \\f_x &= deg_x + \sum \limits_{i = 1}^{k}f_{y_i} \\&= 2(\sum\limits_{i = 1}^{k}size_{y_i} + 1) - 1 \\& = 2size_x - 1\\\end{aligned} fxfx=1+degxi=1k(fyi+fx)=degx+i=1kfyi=2(i=1ksizeyi+1)1=2sizex1

  • 以 1 号结点为根,我们很容易根据上述结论算出答案。
  • 考虑增加一条单向边的影响,设增加的单向边中深度更大的点为 y y y,对于 y y y 的某个祖先 x x x,若 y → x y\to x yx 的路径上没有其它单向边, x x x 第一次游走到其父结点的期望步数会减去 2 s i z e y 2size_y 2sizey
  • 不难发现, y → x y\to x yx 的路径没有其它单向边的概率只与 y → x y \to x yx 的路径长度有关,我们统计出路径长度为特定值的 s i z e y size_y sizey 之和,最后统一计算答案即可。
  • 时间复杂度 O ( n ) \mathcal O(n) O(n)
#include 

template <class T>
inline void read(T &res)
{
	char ch; bool flag = false; res = 0;
	while (ch = getchar(), !isdigit(ch) && ch != '-');
	ch == '-' ? flag = true : res = ch ^ 48;
	while (ch = getchar(), isdigit(ch))
		res = res * 10 + ch - 48;
	flag ? res = -res : 0;
}

template <class T>
inline void put(T x)
{
	if (x > 9)
		put(x / 10);
	putchar(x % 10 + 48);
}

template <class T>
inline void _put(T x)
{
	if (x < 0)
		x = -x, putchar('-');
	put(x);
}

template <class T>
inline void CkMin(T &x, T y) {x > y ? x = y : 0;}
template <class T>
inline void CkMax(T &x, T y) {x < y ? x = y : 0;}
template <class T>
inline T Min(T x, T y) {return x < y ? x : y;}
template <class T>
inline T Max(T x, T y) {return x > y ? x : y;}
template <class T>
inline T Abs(T x) {return x < 0 ? -x : x;}
template <class T>
inline T Sqr(T x) {return x * x;}

using std::map;
using std::set;
using std::pair;
using std::bitset;
using std::string;
using std::vector;
using std::multiset;
using std::priority_queue;

typedef long long ll;
typedef long double ld;
const ld pi = acos(-1.0);
const ld eps = 1e-8;
const int N = 1e6 + 5;
const int Maxn = 1e9;
const int mod = 998244353;
int ans, n, V, K, s;
int sze[N], fa[N], dep[N], pos[N], idx[N];
int cnt[N], fac[N], ifac[N]; bool vis[N];
vector<int> e[N];

inline void add(int &x, int y)
{
	x += y;
	x >= mod ? x -= mod : 0;
}

inline void dec(int &x, int y)
{
	x -= y;
	x < 0 ? x += mod : 0;
}

inline int quick_pow(int x, int k)
{
	int res = 1;
	while (k)
	{
		if (k & 1)
			res = 1ll * res * x % mod;
		x = 1ll * x * x % mod;
		k >>= 1; 
	}
	return res;
}

inline int C(int n, int m)
{
	if (n < 0 || m < 0 || n < m) return 0; 
	return 1ll * fac[n] * ifac[n - m] % mod * ifac[m] % mod;
}

inline void addSubtree(int x, int z)
{
	for (int i = pos[x]; i <= pos[x] + sze[x] - 1; ++i)
	{
		int y = idx[i];
		add(cnt[dep[y] - dep[z]], sze[y]);
		dec(cnt[dep[y] - 1], sze[y]);
	}
}

inline void dfsTraverse(int x)
{
	dep[x] = dep[fa[x]] + 1;
	sze[x] = 1;
	pos[x] = ++V;
	idx[V] = x;
	for (int y : e[x])
	{
		if (y == fa[x])
			continue ;
		fa[y] = x;
		dfsTraverse(y);
		sze[x] += sze[y];
	}
}

int main()
{
	read(n); read(K); read(s);
	for (int i = 1, u, v; i < n; ++i)
	{
		read(u); read(v);
		e[u].push_back(v);
		e[v].push_back(u);
	}
	fac[0] = 1;
	for (int i = 1; i <= n; ++i)
		fac[i] = 1ll * fac[i - 1] * i % mod;
	ifac[n] = quick_pow(fac[n], mod - 2);
	for (int i = n; i >= 1; --i)
		ifac[i - 1] = 1ll * i * ifac[i] % mod;
	
	dfsTraverse(1); 
	while (s != 1)
	{
		add(ans, 2 * sze[s] - 1);
		vis[s] = true;
		add(cnt[1], sze[s]);
		dec(cnt[dep[s] - 1], sze[s]);
		for (int y : e[s])
		{
			if (y == fa[s] || vis[y])	
				continue ;
			addSubtree(y, s);
		}
		s = fa[s];
	}
	for (int i = 1; i <= n - 2; ++i)
		add(cnt[i], cnt[i - 1]);
	int p = quick_pow(C(n - 1, K), mod - 2);
	for (int i = 1; i <= n - 2; ++i)
		dec(ans, 2ll * cnt[i] * p % mod * C(n - 1 - i, K - 1) % mod);
	put(ans), putchar('\n');
	return 0;
}

4A. Task Computing

题目大意

  • 给定长度为 n n n 的序列 w w w p p p
  • 求长度为 m m m 的序列 a a a,满足 1 ≤ a i ≤ n 1 \le a_i \le n 1ain a i a_i ai 两两不同,最大化 ∑ i = 1 m w a i ∏ j = 0 i − 1 p a j \sum \limits_{i = 1}^{m}w_{a_i}\prod\limits_{j = 0}^{i-1}p_{a_j} i=1mwaij=0i1paj
  • 规定 a 0 = 0 , p 0 = 1 a_0 = 0,p_0=1 a0=0,p0=1,且 1 ≤ n ≤ 1 0 5 , 1 ≤ m ≤ min ⁡ { n , 20 } , 1 ≤ w i ≤ 1 0 9 , 0.8 ≤ p i ≤ 1.2 1 \le n \le 10^5,1\le m \le \min\{n,20\}, 1\le w_i \le 10^9, 0.8 \le p_i \le 1.2 1n105,1mmin{n,20},1wi109,0.8pi1.2

题解

  • 这类收益与排列的前缀有关的题目有一定的套路性。
  • 考虑序列 a a a 中相邻两个下标 x x x y y y,此时与交换 x x x y y y 后的收益之差为:
    w a x + w a y p a x − w a y − w a x p a y w_{a_x} + w_{a_y}p_{a_x} - w_{a_y} - w_{a_x}p_{a_y} wax+waypaxwaywaxpay
    ​ 最优解应满足收益之差大于等于 0,稍作移项之后得:
    w a x 1 − p a x ≥ w a y 1 − p a y \frac{w_{a_x}}{1 - p_{a_x}} \ge \frac{w_{a_y}}{1 - p_{a_y}} 1paxwax1payway
    ​ 可将不等式两边的式子视作 a x a_x ax a y a_y ay 的属性,则该式具有传递性(分母等于 0 以及小于 0 的情况需要稍加讨论,但其传递性也是正确的)。
  • 若将序列 w w w p p p 按上式排序,序列 a a a 取新序列的下标时一定递增,方便我们计算收益。
  • f i , j f_{i,j} fi,j 表示已经考虑了新序列的第 i i i 个至第 n n n 个下标、序列 a a a 的后 j j j 个数已确定的最大收益,就能避免正向 DP \text{DP} DP 需要记录 p p p 的前缀积而产生了后效性,不难得到转移:
    f i , j = max ⁡ { f i + 1 , j , f i + 1 , j − 1 × p i + w i } f_{i,j} = \max\{f_{i + 1,j}, f_{i + 1, j - 1} \times p_i + w_i\} fi,j=max{fi+1,j,fi+1,j1×pi+wi}
  • 时间复杂度 O ( n log ⁡ n + n m ) \mathcal O(n \log n + nm) O(nlogn+nm)
#include 

template <class T>
inline void CkMax(T &x, T y) {x < y ? x = y : 0;}

typedef long double ld;
const int N = 1e5 + 3;

struct node
{
	double pi, wi;	
};
node a[N];
int n, m;
ld f[N][25];

inline bool cmp(const node &x, const node &y)
{
	return x.wi + y.wi * x.pi > y.wi + x.wi * y.pi;	
}

int main()
{
	scanf("%d%d", &n, &m);
	for(int i = 1; i <= n; ++i) 
		scanf("%lf", &a[i].wi);
	for(int i = 1; i <= n; ++i)
	{
		scanf("%lf", &a[i].pi);
		a[i].pi /= 10000.0;
	}
	std::sort(a + 1, a + n + 1, cmp);
	ld maxx = 0;
	for (int i = n; i >= 1; --i)
	{
		for (int j = 0; j <= m; ++j)
			f[i][j] = f[i + 1][j];
		for (int j = 1; j <= m; ++j)
			CkMax(f[i][j], f[i + 1][j - 1] * a[i].pi + a[i].wi);
	}
	printf("%.15lf\n", (double)f[1][m]);
    return 0;
}

4C. Easy Counting Problem

题目大意

  • q q q 次询问,每次询问给出 n n n,求符合以下条件的本质不同的字符串个数:
    • 每位的字符为 0 , 1 , … , w − 1 0, 1, \dots, w - 1 0,1,,w1
    • 字符 i i i 出现至少 c i c_i ci 次。
    • 字符串的长度为 n n n
  • 1 ≤ q ≤ 300 , 1 ≤ n ≤ 1 0 7 , 2 ≤ w ≤ 10 , ∑ i = 0 w − 1 c i ≤ 5 × 1 0 4 1 \le q \le 300, 1 \le n \le 10^7,2 \le w \le 10, \sum \limits_{i = 0}^{w - 1}c_i \le 5\times 10^4 1q300,1n107,2w10,i=0w1ci5×104

题解

  • E G F \mathbb{EGF} EGF 可得所求即为:
    n ! [ x n ] ∏ i = 0 w − 1 ( e x − ∑ j = 0 c i − 1 x j j ! ) n![x^n]\prod\limits_{i = 0}^{w - 1}(e^x - \sum \limits_{j = 0}^{c_i - 1}\frac{x^j}{j!}) n![xn]i=0w1(exj=0ci1j!xj)
  • 暴力卷积复杂度太劣,考虑设 y = e x y = e^x y=ex,把后面的乘积看做 y y y 的多项式, x x x 的和式看做这一多项式的系数,系数相乘时用 NTT 优化,其余暴力实现即可。
  • 询问时即查询对于每个 y k y^k yk 对应的系数。
  • c = ∑ i = 0 w − 1 c i c = \sum\limits_{i = 0}^{w - 1}c_i c=i=0w1ci,时间复杂度 O ( w 2 c log ⁡ c + q w c ) \mathcal O(w^2 c\log c + qwc) O(w2clogc+qwc)
#include 

template <class T>
inline void read(T &res)
{
	char ch; bool flag = false; res = 0;
	while (ch = getchar(), !isdigit(ch) && ch != '-');
	ch == '-' ? flag = true : res = ch ^ 48;
	while (ch = getchar(), isdigit(ch))
		res = res * 10 + ch - 48;
	flag ? res = -res : 0;
}

template <class T>
inline void put(T x)
{
	if (x > 9)
		put(x / 10);
	putchar(x % 10 + 48);
}

template <class T>
inline void _put(T x)
{
	if (x < 0)
		x = -x, putchar('-');
	put(x);
}

template <class T>
inline void CkMin(T &x, T y) {x > y ? x = y : 0;}
template <class T>
inline void CkMax(T &x, T y) {x < y ? x = y : 0;}
template <class T>
inline T Min(T x, T y) {return x < y ? x : y;}
template <class T>
inline T Max(T x, T y) {return x > y ? x : y;}
template <class T>
inline T Abs(T x) {return x < 0 ? -x : x;}
template <class T>
inline T Sqr(T x) {return x * x;}

using std::map;
using std::set;
using std::pair;
using std::bitset;
using std::string;
using std::vector;
using std::complex;
using std::multiset;
using std::priority_queue;

typedef long long ll;
typedef long double ld;
typedef complex<ld> com;
typedef pair<int, int> pir;
const ld pi = acos(-1.0);
const ld eps = 1e-8;
const int N = 1e7 + 5;
const int L = 5e4 + 5;
const int L4 = 2e5 + 5;
const int Maxn = 1e9;
const int Minn = -1e9;
const int mod = 998244353;
const int lim = 1e7;
int T_data, n, w, q;
int fac[N], ifac[N];
vector<int> e[12], h[12], a;

inline void add(int &x, int y)
{
	x += y;
	x >= mod ? x -= mod : 0;
}

inline void dec(int &x, int y)
{
	x -= y;
	x < 0 ? x += mod : 0;
}

inline int quick_pow(int x, int k)
{
	int res = 1;
	while (k)
	{
		if (k & 1)
			res = 1ll * res * x % mod;
		x = 1ll * x * x % mod;
		k >>= 1;
	}
	return res;
}

const int inv2 = 499122177;
const int inv3 = 332748118;
int rev[L4], tw[L4], inv[L4];

inline void operator += (vector<int> &a, vector<int> b)
{
	int n = b.size();
	a.resize(Max((int)a.size(), n));
	for (int i = 0; i < n; ++i)
		add(a[i], b[i]);
}

inline void operator -= (vector<int> &a, vector<int> b)
{
	int n = b.size();
	a.resize(Max((int)a.size(), n));
	for (int i = 0; i < n; ++i)
		dec(a[i], b[i]);
}

inline void operator *= (vector<int> &a, int k)
{
	if (k == -1)
	{
		int n = a.size();
		for (int i = 0; i < n; ++i)
			if (a[i])
				a[i] = mod - a[i];
	}
	else 
	{
		int n = a.size();
		for (int i = 0; i < n; ++i)
			a[i] = 1ll * k * a[i] % mod;
	}
}
 
inline void DFT(vector<int> &a, int opt)
{
	int n = a.size(), g = opt == 1 ? 3 : inv3;
	for (int i = 0; i < n; ++i)
		if (i < rev[i])
			std::swap(a[i], a[rev[i]]);
	for (int k = 1; k < n; k <<= 1)
	{
		int w = quick_pow(g, (mod - 1) / (k << 1));
		tw[0] = 1;
		for (int j = 1; j < k; ++j)
			tw[j] = 1ll * tw[j - 1] * w % mod;
		for (int i = 0; i < n; i += k << 1)
		{
			for (int j = 0; j < k; ++j)
			{
				int u = a[i + j],
					v = 1ll * tw[j] * a[i + j + k] % mod;
				add(a[i + j] = u, v);
				dec(a[i + j + k] = u, v);
			}
		}
	}
	if (opt == -1)
	{
		int inv_n = quick_pow(n, mod - 2); 
		for (int i = 0; i < n; ++i)
			a[i] = 1ll * a[i] * inv_n % mod;
	}
}

inline void polyMul(vector<int> &a, vector<int> b)
{
	if (!a.size() || !b.size())
	{
		a.clear();
		return ;
	}
	int m = 0, _n = a.size() + b.size() - 2, n;
	for (n = 1; n <= _n; n <<= 1)
		++m;
	a.resize(n);
	b.resize(n);
	for (int i = 1; i < n; ++i)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (m - 1));
	DFT(a, 1); DFT(b, 1);
	for (int i = 0; i < n; ++i)
		a[i] = 1ll * a[i] * b[i] % mod;
	DFT(a, -1);
	a.resize(_n + 1);
}

int main()
{
	read(w);
	fac[0] = 1;
	for (int i = 1; i <= lim; ++i)
		fac[i] = 1ll * i * fac[i - 1] % mod;
	ifac[lim] = quick_pow(fac[lim], mod - 2);
	for (int i = lim; i >= 1; --i)
		ifac[i - 1] = 1ll * ifac[i] * i % mod;
	
	e[0].push_back(1);
	for (int i = 0, c; i < w; ++i)
	{
		read(c);
		a.resize(c);
		for (int j = 0; j < c; ++j)
			a[j] = ifac[j];
		for (int j = 0; j <= i; ++j)
		{
			h[j + 1] = e[j];
			polyMul(e[j], a);
			e[j] *= -1;
		}
		for (int j = 0; j <= i + 1; ++j)
			e[j] += h[j];
	}
	read(q);
	while (q--)
	{
		read(n);
		int ans = e[0].size() > n ? e[0][n] : 0;
		for (int i = 1; i <= w; ++i)
		{
			int m = Min((int)e[i].size() - 1, n),
				res = quick_pow(i, n - m);
			for (int j = m; j >= 0; --j)
			{
				ans = (1ll * e[i][j] * res % mod * ifac[n - j] + ans) % mod;
				res = 1ll * res * i % mod; 
			}
		}
		put(1ll * fac[n] * ans % mod), putchar('\n');
	}
	return 0;
}

你可能感兴趣的:(多校,算法,数据结构)