NOI2019 机器人(多项式+dp)

题目链接

题解

首先可以考虑暴力区间 dp, f i , j , k f_{i,j,k} fi,j,k 表示区间 [ i , j ] [i,j] [i,j] 中,所有数字都不超过 k k k 的方案数。显然可以枚举 l l l,使 ∣ ( l − i ) − ( j − l ) ∣ ≤ 2 |(l-i)-(j-l)| \le 2 (li)(jl)2 f i , j , k = f i , j , k − 1 + ∑ l f i , l − 1 , k f l + 1 , j , k − 1 f_{i,j,k}=f_{i,j,k-1}+\sum_lf_{i,l-1,k}f_{l+1,j,k-1} fi,j,k=fi,j,k1+lfi,l1,kfl+1,j,k1,于是就得到了一个 O ( n 2 B ) O(n^2B) O(n2B) 的做法,但注意到合法区间的个数不会很多,用记忆化搜索可以做到大概 O ( 10 n B ) O(10nB) O(10nB) 级别。

不妨再考虑 A = 1 , B = 1 0 9 A=1,B=10^9 A=1,B=109,对于所有区间的限制都一样,可以猜测答案是和 B B B 有关的多项式,显然 i = j i=j i=j 时, f f f 是一个一次函数,乘法时会对 l − i l-i li 次和 j − l j-l jl 次多项式做乘法,得到一个 i − j i-j ij 次的多项式。但由于要求和,次数会加 1,于是任意 dp 数组就是一个关于 k k k i − j + 1 i-j+1 ij+1 次多项式。

但如果 A , B A,B A,B 任意,可以猜想到这是关于 k k k 的分段多项式,转移的话大概要求 ∑ i = 1 x i k \sum_{i=1}^x i^k i=1xik 这种东西,即一个多项式的“离散积分”。这里有两种维护方法,第一种就是暴力用伯努利数等算出上面多项式的值,每次计算积分 O ( n 2 ) O(n^2) O(n2) 暴力计算一项对答案的贡献即可。

第二种方法是维护下降幂多项式,即把一个多项式表示成 ∑ a i x i ‾ \sum a_ix^{\underline i} aixi 的形式,由于下降幂便于计算积分,即从 1 积到 x − 1 x-1 x1 的结果为 ∑ a i i + 1 x i + 1 ‾ \sum \frac{a_i}{i+1}x^{\underline{i+1}} i+1aixi+1,一次积分的复杂度仅为 O ( n ) O(n) O(n)。下降幂多项式的乘法可以考虑对于第一个多项式的每一项 a i x i ‾ a_ix^{\underline i} aixi,把第二个多项式表示成 ∑ b i ( x − i ) i ‾ \sum b_i(x-i)^{\underline i} bi(xi)i 的形式,就可以直接计算乘法了。后者的系数可以直接由 x i ‾ = ( x − 1 ) i ‾ + i ( x − 1 ) i − 1 ‾ x^{\underline i}=(x-1)^{\underline i}+i(x-1)^{\underline{i-1}} xi=(x1)i+i(x1)i1 递推,因此乘法的复杂度为 O ( n 2 ) O(n^2) O(n2)

因此我们需要支持:把两个分段多项式相加、相乘,把一个分段多项式求离散积分,把一个分段多项式 < l <l <l > r >r >r 的部分清零。这些操作均可以在 O ( O( O(段数 × \times ×次数 2 ) ^2) 2) 的时间内解决。使用记忆化搜索即可通过。

#include 
using namespace std;
typedef long long ll;

const int MAXN = 305, MOD = 1000000007;
ll modpow(ll a, int b) {
	ll res = 1;
	for (; b; b >>= 1) {
		if (b & 1) res = res * a % MOD;
		a = a * a % MOD;
	}
	return res;
}
ll inv[MAXN], tmp[MAXN];
struct Poly {
	vector<int> po; int n;
	Poly(int m = 0) {
		po.resize(n = m);
	}
	int &operator[](int x) { return po[x]; }
	const int &operator[](int x) const { return po[x]; }
	Poly operator+(const Poly &p) const {
		Poly r(max(n, p.n));
		for (int i = 0; i < r.n; i++)
			r[i] = ((i < n ? po[i] : 0) + (i < p.n ? p[i] : 0)) % MOD;
		return r;
	}
	Poly operator*(const Poly &p) const {
		Poly r(n + p.n - 1);
		for (int i = 0; i < p.n; i++) tmp[i] = p[i];
		for (int i = 0; i < n; i++) {
			for (int j = 0; j < p.n; j++)
				r[i + j] = (tmp[j] * po[i] + r[i + j]) % MOD;
			for (int j = 1; j < p.n; j++)
				tmp[j - 1] = (tmp[j - 1] + tmp[j] * j) % MOD;
		}
		return r;
	}
	Poly inter() const {
		Poly r(n + 1);
		for (int i = 0; i < n; i++)
			r[i + 1] = po[i] * inv[i + 1] % MOD;
		return r;
	}
	ll operator()(int x) const {
		ll res = 0, mul = 1;
		for (int i = 0; i < n; i++) {
			res = (res + mul * po[i]) % MOD;
			mul = mul * (x - i + MOD) % MOD;
		}
		return res;
	}
};
struct Func {
	vector<Poly> p;
	vector<int> x;
	Func op(const Func &f, int tp) const {
		Func r; int sa = x.size(), sb = f.x.size();
		for (int a = 0, b = 0, pos = 0;;) {
			r.x.push_back(pos);
			r.p.push_back(tp ? p[a] * f.p[b] : p[a] + f.p[b]);
			if (a + 1 == sa && b + 1 == sb) break;
			if (b + 1 == sb || (a + 1 < sa && x[a + 1] < f.x[b + 1])) pos = x[++a];
			else pos = f.x[++b];
			while (a + 1 < sa && x[a + 1] <= pos) ++a;
			while (b + 1 < sb && f.x[b + 1] <= pos) ++b;
		}
		return r;
	}
	Func lim(int a, int b) const {
		Func r; Poly o(1);
		r.p.push_back(o), r.x.push_back(0);
		int t = x.size();
		for (int i = 0; i < t; i++)
			if (x[i] <= b && (i + 1 == t || x[i + 1] > a)) {
				r.p.push_back(p[i]);
				r.x.push_back(max(x[i], a));
			}
		r.x.push_back(b + 1);
		r.p.push_back(o);
		return r;
	}
	Func inter() {
		int t = x.size();
		Func f;
		for (int i = 0; i < t; i++) {
			f.p.push_back(p[i].inter());
			f.x.push_back(x[i]);
			if (i > 0) f.p[i][0] =
				(f.p[i][0] - f.p[i](x[i]) + f.p[i - 1](x[i]) + MOD) % MOD;
		}
		return f;
	}
	ll sum() { return inter().p.back()[0]; }
} f[10005];
int vis[305][305], A[305], B[305], tot, n;
void dfs(int l, int r) {
	if (vis[l][r]) return;
	int now = vis[l][r] = ++tot;
	Poly p(1);
	f[tot].p.push_back(p);
	f[tot].x.push_back(0);
	if (l > r) {
		f[tot].p[0][0] = 1;
		return;
	}
	for (int i = l; i <= r; i++) if (abs((r - i) - (i - l)) <= 2) {
		dfs(l, i - 1);
		dfs(i + 1, r);
		Func a = f[vis[l][i - 1]], b = f[vis[i + 1][r]];
		if (l < i) a = a.op(a.inter(), 0);
		if (i < r) b = b.inter();
		f[now] = f[now].op(a.op(b, 1).lim(A[i], B[i]), 0);
	}
}
int main() {
    freopen("robot.in", "r", stdin);
    freopen("robot.out", "w", stdout);
	scanf("%d", &n);
	for (int i = 1; i <= n; i++) scanf("%d%d", A + i, B + i);
	inv[1] = 1;
	for (int i = 2; i < 305; i++) inv[i] = MOD - (ll)MOD / i * inv[MOD % i] % MOD;
	dfs(1, n);
	printf("%lld\n", f[vis[1][n]].sum());
	return 0;
}

你可能感兴趣的:(多项式,dp)