【LOJ3075】「2019 集训队互测 Day 3」组合数求和

【题目链接】

  • 点击打开链接

【思路要点】

  • 所求的 $f$ 即为 $\sum_{i=0}^{n-1}(x+1)^{id}=\frac{(x+1)^{nd}-1}{(x+1)^d-1}$ 的前 $m$ 项系数。
  • 设 $A(x)=\frac{(x+1)^{nd}-1}{x}$，$B(x)=\frac{(x+1)^d-1}{x}$，$C(x)=\sum_{i=0}^{n-1}(x+1)^{id}$，那么 $A(x)=B(x)C(x)$。其中 $B(x)=\sum_{j=0}^{d-1}\binom{d}{j+1}x^j$ 的次数为 $d-1$，而我们只需要 $C(x)$ 的前 $m$ 项系数。
  • 考虑计算模数为 $p^k$ 时的 $C(x)$，最后用中国剩余定理合并答案。
  • 比较 $A(x)=B(x)C(x)$ 两端的系数，有 $[x^i]A(x)=\sum_{j=0}^{d-1}[x^j]B(x)\,[x^{i-j}]C(x)$。
  • 找到最小的 $t$，使得 $[x^t]B(x)$ 不是 $p$ 的倍数，则有
    $$[x^t]B(x)\,[x^i]C(x)=[x^{i+t}]A(x)-\sum_{j=0}^{t-1}[x^j]B(x)\,[x^{i+t-j}]C(x)-\sum_{j=t+1}^{d-1}[x^j]B(x)\,[x^{i+t-j}]C(x)$$
  • 由于 $[x^t]B(x)$ 不是 $p$ 的倍数，它在模 $p^k$ 意义下存在乘法逆元，我们可以将等式两侧乘以该逆元。
  • 可以发现，对于 $j>i$，上式中 $[x^j]C(x)$ 前的系数均为 $p$ 的倍数（它们来自 $j<t$ 的 $[x^j]B(x)$）。因此，若将等式右侧这些 $[x^j]C(x)$（$j>i$）按从高次到低次的顺序再次代入上式，可以使得所得式中 $j>i$ 的 $[x^j]C(x)$ 前的系数均为 $p^2$ 的倍数。重复这一过程，直到这些系数均为 $p^k$ 的倍数（即模 $p^k$ 为 $0$），此时 $[x^i]C(x)$ 的计算式中只包含 $A(x)$ 中的项以及 $C(x)$ 中次数小于 $i$ 的项，从而可以直接递推计算答案。
  • 计算模数为 $p^k$ 时的 $C(x)$ 的时间复杂度为 $O(mdk+d^2k^2)$。
  • 总时间复杂度为 $O(m\log M+m\log N+md\log M+d^2\log^2 M)$。

【代码】

#include <algorithm>
#include <cctype>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>
using namespace std;
// Length of the global coefficient arrays a, b, c and res.
constexpr int MAXN = 4000005;
// Length of solve()'s scratch arrays; calcab() also fills a[] this far past m.
constexpr int MAXQ = 4005;
using ll = long long;
using ld = long double;
using ull = unsigned long long;
// Raises x to y when y compares greater.
template <typename T> void chkmax(T &x, T y) {
	if (x < y) x = y;
}
// Lowers x to y when y compares smaller.
template <typename T> void chkmin(T &x, T y) {
	if (y < x) x = y;
}
// Reads the next decimal integer from stdin into x. Non-digit characters before the
// number are skipped, and every '-' among them flips the sign. If input ends before
// any digit is found, x is set to 0 instead of looping forever.
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	// Keep getchar()'s result as int: it must be able to hold EOF, and isdigit() is only
	// defined for EOF or values representable as unsigned char.
	int c = getchar();
	for (; c != EOF && !isdigit(c); c = getchar()) if (c == '-') f = -f;
	for (; c != EOF && isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}
// Writes x to stdout in decimal (leading '-' for negatives), with no trailing newline.
// Works for the most negative value of T as well.
template <typename T> void write(T x) {
	if (x < 0) {
		putchar('-');
		// Split off the last digit *before* negating. -x overflows when x is the most
		// negative value of T, but -(x / 10) never does.
		const T head = x / 10, last = x % 10;  // head <= 0, last in (-10, 0]
		if (head != 0) write(-head);
		putchar('0' - last);
		return;
	}
	if (x > 9) write(x / 10);
	putchar(x % 10 + '0');
}
// Writes x in decimal followed by a single newline.
template <typename T> void writeln(T x) {
	write(x);
	putchar('\n');
}
int n;  // number of terms: exponents i * d for i = 0 .. n - 1
int m;  // number of coefficients computed and XOR-ed into the answer
int d;  // exponent step
int P;  // modulus (read from input; may be composite)
// Factorisation of P as (prime, exponent) pairs, filled by factor().
vector<pair<int, int>> p;
// Returns x^y modulo the global P using iterative square-and-multiply.
// power(x, 0) is 1. A negative y is treated as |y|.
int power(int x, int y) {
	unsigned left = y < 0 ? 0u - static_cast<unsigned>(y) : static_cast<unsigned>(y);
	long long acc = 1, sq = x;
	for (; left != 0; left >>= 1) {
		if (left & 1u) acc = acc * sq % P;
		if (left > 1) sq = sq * sq % P;
	}
	return static_cast<int>(acc);
}
// Extended Euclid: stores in x, y a pair with a * x + b * y == gcd(a, b).
void exgcd(int a, int b, int &x, int &y) {
	if (!b) {
		x = 1;
		y = 0;
		return;
	}
	int x1 = 0, y1 = 0;
	exgcd(b, a % b, x1, y1);
	// From b * x1 + (a % b) * y1 == g, with a % b == a - (a / b) * b.
	x = y1;
	y = x1 - (a / b) * y1;
}
// Returns the inverse of x modulo P, normalised to [0, P).
// Requires gcd(x, P) == 1; otherwise the returned value is not an inverse.
int inv(int x, int P) {
	int s = 0, t = 0;
	exgcd(x, P, s, t);  // s * x + t * P == gcd(x, P)
	// s % P lies in (-P, P). Add P only when the remainder is negative, so the sum stays
	// below P; the former (s % P + P) % P overflowed int once P > INT_MAX / 2.
	const int r = s % P;
	return r < 0 ? r + P : r;
}
// Chinese remainder merge of r == a (mod P) and r == b (mod Q) for coprime P, Q.
// invP must be P^{-1} mod Q and invQ must be Q^{-1} mod P. P * Q is assumed to fit in
// int; for non-negative inputs the result lies in [0, P * Q).
int crt(int a, int P, int invP, int b, int Q, int invQ) {
	const int whole = P * Q;
	const long long partA = 1ll * a * Q % whole * invQ;
	const long long partB = 1ll * b * P % whole * invP;
	return static_cast<int>((partA + partB) % whole);
}
// Global coefficient buffers with MAXN entries each:
//   a[i], b[i] - binomial rows C(ea, i + 1), C(eb, i + 1) mod P, filled by calcab();
//   c[i]       - result for the current prime-power modulus, filled by solve();
//   res[i]     - results merged by CRT over the prime powers processed so far (main()).
int a[MAXN];
int b[MAXN];
int c[MAXN];
int res[MAXN];
void calcab(int ea, int eb, int m) {
      
	int lft = 1; vector <int> e(p.size());
	for (int i = 0; i < ea && i < m + MAXQ; i++) {
      
		int tmp = ea - i, tnp = i + 1;
		for (unsigned i = 0; i < p.size(); i++) {
      
			while (tmp % p[i].first == 0) {
      
				tmp /= p[i].first;
				e[i]++;
			}
			while (tnp % p[i].first == 0) {
      
				tnp /= p[i].first;
				e[i]--;
			}
		}
		lft = 1ll * lft * tmp % P * inv(tnp, P) % P;
		int res = lft;
		for (unsigned i = 0; i < p.size(); i++)
			res = 1ll * res * power(p[i].first, e[i]) % P;
		a[i] = res;
	}
	lft = 1; e.clear(), e.resize(p.size());
	for (int i = 0; i < eb; i++) {
      
		int tmp = eb - i, tnp = i + 1;
		for (unsigned i = 0; i < p.size(); i++) {
      
			while (tmp % p[i].first == 0) {
      
				tmp /= p[i].first;
				e[i]++;
			}
			while (tnp % p[i].first == 0) {
      
				tnp /= p[i].first;
				e[i]--;
			}
		}
		lft = 1ll * lft * tmp % P * inv(tnp, P) % P;
		int res = lft;
		for (unsigned i = 0; i < p.size(); i++)
			res = 1ll * res * power(p[i].first, e[i]) % P;
		b[i] = res;
	}
}
// Factorises x by trial division and appends its (prime, exponent) pairs to the global
// p, in increasing order of prime.
void factor(int x) {
	// i <= x / i is equivalent to i * i <= x for i >= 2, without the int overflow of
	// i * i when x is close to INT_MAX (e.g. a large prime modulus).
	for (int i = 2; i <= x / i; i++)
		if (x % i == 0) {
			int cnt = 0;
			while (x % i == 0) {
				x /= i;
				cnt++;
			}
			p.emplace_back(i, cnt);
		}
	if (x != 1) p.emplace_back(x, 1);  // leftover factor larger than sqrt of original x
}
// Computes c[i] = [x^i] C(x) mod P for 0 <= i < m, where P = p^k is one prime-power
// factor of the global modulus. a[] and b[] hold the coefficients of A and B in
// A = B * C (see calcab()).
//
// Let t be the lowest index with b[t] not divisible by p. Solving A = B * C for c_i:
//   c_i = b[t]^{-1} * (a_{i+t} - sum_{j != t} b[j] * c_{i+t-j}).
// The terms with j < t refer to higher coefficients of C, but their factors are
// divisible by p. They are re-expanded repeatedly, highest index first, until their
// factors vanish modulo p^k. This leaves a recurrence in a[] and lower c[] only.
//
// Work arrays are indexed by offset. After the re-centring below (pos = d - 1 - t),
// entry q of coef / mula multiplies c / a at index i + q - pos.
void solve(int p, int k, int P) {
	int pos = 0;
	while (b[pos] % p == 0) pos++;  // pos = t
	int mul = inv(b[pos], P);       // b[t]^{-1} mod p^k
	static int coef[MAXQ], mula[MAXQ], func[MAXQ];
	memset(coef, 0, sizeof(coef));
	memset(mula, 0, sizeof(mula));
	memset(func, 0, sizeof(func));
	// One application of the identity. coef/mula are the running factors on c/a;
	// func keeps a copy of the single-step factors for later substitutions.
	for (int i = 0; i <= d - 1; i++) {
		if (i != pos) coef[d - 1 - i] = func[d - 1 - i] = 1ll * mul * (P - b[i] % P) % P;
		else mula[d - 1] = mul;
	}
	pos = d - 1 - pos;
	// NOTE(review): offsets touched reach roughly d * (k + 2); this assumes that stays
	// below MAXQ - confirm against the problem constraints.
	int Limit = pos + d * (k + 1);
	// Pass e makes every factor on a higher-index c divisible by p^e. v is long long
	// because p * p (evaluated even when k == 1) and the final v *= p can exceed INT_MAX.
	long long v = 1ll * p * p;
	for (int e = 2; e <= k; e++, v *= p)
		for (int i = Limit; i >= pos; i--) {
			if (coef[i] % v == 0) continue;
			int tmp = coef[i]; coef[i] = 0;
			mula[i + d - 1 - pos] = (mula[i + d - 1 - pos] + 1ll * tmp * mul) % P;
			for (int j = 0; j <= d - 1; j++)
				coef[i - pos + j] = (coef[i - pos + j] + 1ll * tmp * func[j]) % P;
		}
	while (Limit > pos && mula[Limit] == 0) Limit--;
	for (int i = 0; i <= m - 1; i++) {
		int sum = 0;
		for (int j = pos; j <= Limit; j++)
			sum = (sum + 1ll * mula[j] * a[i + j - pos]) % P;
		// Only lower coefficients c[i - j] remain; c at negative index is zero.
		for (int j = 1; j <= pos && j <= i; j++)
			sum = (sum + 1ll * coef[pos - j] * c[i - j]) % P;
		// Every term is non-negative, so sum is already in [0, P).
		c[i] = sum;
	}
}
// Reads n, m, d, P and, for each 0 <= i < m, computes [x^i] sum_{j<n} (x+1)^{jd} mod P.
// It solves modulo every prime power of P, merges the results with CRT, and prints the
// XOR of the m results.
int main() {
	read(n), read(m), read(d), read(P);
	// NOTE(review): n * d is evaluated in int, which assumes n * d <= INT_MAX. Confirm
	// against the problem constraints.
	factor(P), calcab(n * d, d, m);
	int Mod = 1;  // product of the prime powers already merged into res[]
	for (const auto &x : p) {
		// Q = p^k exactly. It divides P, so it fits in int. Computing it directly avoids
		// power(), which reduces modulo P and would yield 0 when P itself equals p^k.
		int Q = 1;
		for (int t = 0; t < x.second; t++) Q *= x.first;
		solve(x.first, x.second, Q);
		int invMod = inv(Mod % Q, Q);
		int invQ = inv(Q % Mod, Mod);
		for (int i = 0; i <= m - 1; i++)
			res[i] = crt(res[i], Mod, invMod, c[i], Q, invQ);
		Mod *= Q;
	}
	int ans = 0;
	for (int i = 0; i <= m - 1; i++)
		ans ^= res[i];
	writeln(ans);
	return 0;
}

你可能感兴趣的:(【OJ】LOJ,【类型】做题记录,【算法】生成函数,【算法】中国剩余定理)