此篇文章可能写的很不标准,请读者见谅。
常系数齐次线性递推,可以理解为dp,大概是给出 f [ 0.. k − 1 ] f[0..k-1] f[0..k−1]的初值,
对于 i > = k , f [ i ] = ∑ j = 1 k − 1 f [ i − j ] ∗ c [ j ] i>=k,f[i]=\sum_{j=1}^{k-1}f[i-j]*c[j] i>=k,f[i]=∑j=1k−1f[i−j]∗c[j], c c c也是给出的转移系数。
我们的目的是求 f [ n ] f[n] f[n], n n n很大
强套矩阵乘法:
O ( k 3 ∗ l o g n ) O(k^3*logn) O(k3∗logn)
假设我们有一个转移系数 b b b,满足 f [ n ] = ∑ i = 0 n f [ i ] ∗ b [ i ] f[n]=\sum_{i=0}^n f[i]*b[i] f[n]=∑i=0nf[i]∗b[i],我们的终极目的是把 b b b化简,使只有 b [ 0 − k − 1 ] b[0-k-1] b[0−k−1]有值,那么有 f [ n ] = ∑ i = 0 k − 1 f [ i ] ∗ b [ i ] f[n]=\sum_{i=0}^{k-1}f[i]*b[i] f[n]=∑i=0k−1f[i]∗b[i]
显然是存在这样的 b b b的,也很好归纳证明:
因为初值是 f [ 0.. k − 1 ] f[0..k-1] f[0..k−1], f [ 0.. k − 1 ] f[0..k-1] f[0..k−1]肯定有对应的 b b b, f [ k ] f[k] f[k]显然也有对应的 b b b,就是 c c c,再看 f [ k + 1 ] f[k+1] f[k+1],直接由转移式可得它由 f [ 1.. k ] f[1..k] f[1..k]乘上 c c c得来,由于 f [ 1.. k ] f[1..k] f[1..k]都可以由 f [ 0.. k − 1 ] f[0..k-1] f[0..k−1]乘上对应 b b b得来,我们只需要展开一下,就得到了 f [ k + 1 ] f[k+1] f[k+1]对应的 b b b,归纳下去对任意的 f [ n ] f[n] f[n]也成立。
假设现在有一个 b b b,b的最高非0项是 b [ x ] ( x > = k ) b[x](x>=k) b[x](x>=k),考虑怎么把 b [ x ] b[x] b[x]弄掉呢?
显然就是 b [ x − i ] + = b [ x ] ∗ c [ i ] ( 1 < = i < k ) b[x-i]+=b[x]*c[i](1<=i<k) b[x−i]+=b[x]∗c[i](1<=i<k),即把b的影响分到 b [ x − 1.. x − ( k − 1 ) ] b[x-1..x-(k-1)] b[x−1..x−(k−1)]去
发现这和取模的过程十分类似,
回想一下多项式取模的暴力过程,取最高项系数之商,然后被除式减去除式乘商
我们可以构造多项式 M M M:
M [ x k ] = 1 M[x^k]=1 M[xk]=1
M [ x i ] = − c [ k − i ] ( 1 < = i < k ) M[x^i]=-c[k-i](1<=i<k) M[xi]=−c[k−i](1<=i<k)
那么对任意合法 b b b,直接 M o d Mod Mod一下 M M M就可以得到简化后的式子。
那么 b b b怎么求呢?
最简单的b不就是 x n x^n xn吗?
设多项式 A = x 1 A=x^1 A=x1,我们就是求 A n M o d M A^n~Mod~M An Mod M
暴力 M o d Mod Mod:
O ( k 2 ∗ l o g n ) O(k^2*log~n) O(k2∗log n)
多项式取 M o d Mod Mod:
O ( k ∗ l o g k ∗ l o g n ) O(k*log~k*log~n) O(k∗log k∗log n)
https://www.luogu.org/problemnew/show/P4723
Code:
#include
#include
#include
#include
#define ll long long
#define fo(i, x, y) for(int i = x, B = y; i <= B; i ++)
#define ff(i, x, y) for(int i = x, B = y; i < B; i ++)
#define fd(i, x, y) for(int i = x, B = y; i >= B; i --)
#define pp printf
#define hh pp("\n")
#define cp(a, b, c) fo(i, b, c) cout << a[i] << ' '
#define pb push_back
using namespace std;
const int mo = 998244353;
ll ksm(ll x, ll y) {
ll s = 1;
for(; y; y /= 2, x = x * x % mo)
if(y & 1) s = s * x % mo;
return s;
}
const int N = 1 << 18;
ll a[N], b[N]; int r[N];
void dft(ll *a, int tp, int F) {
int n = 1 << tp;
ff(i, 0, n) if(r[i] > i) swap(a[r[i]], a[i]);
for(int h = 1; h < n; h *= 2) {
ll wn = ksm(ksm(3, (mo - 1) / h / 2), F == 1 ? 1 : mo - 2);
for(int j = 0; j < n; j += 2 * h) {
ll A, *l = a + j, *r = a + j + h, w = 1;
ff(i, 0, h) {
A = *r * w, *r = (*l - A) % mo, *l = (*l + A) % mo;
w = w * wn % mo, l ++, r ++;
}
}
}
if(F == -1) {
ll v = ksm(n, mo - 2);
ff(i, 0, n) a[i] = (a[i] + mo) * v % mo;
}
}
void fft(ll *a, ll *b, int tp) {
int n = 1 << tp;
ff(i, 0, n) r[i] = r[i / 2] / 2 + (i & 1) * n / 2;
dft(a, tp, 1); dft(b, tp, 1);
ff(i, 0, n) (a[i] *= b[i]) %= mo;
dft(a, tp, -1);
}
typedef vector<int> V;
V operator *(V p, V q) {
int n = p.size() + q.size() - 1;
int tp = 0; while(1 << ++ tp < n);
ff(i, 0, 1 << tp) a[i] = b[i] = 0;
ff(i, 0, p.size()) a[i] = p[i];
ff(i, 0, q.size()) b[i] = q[i];
fft(a, b, tp); p.clear();
ff(i, 0, n) p.pb(a[i]);
return p;
}
V operator + (V p, V q) {
while(p.size() < q.size()) p.pb(0);
ff(i, 0, q.size()) p[i] += q[i], p[i] > mo ? p[i] -= mo : 0;
return p;
}
V operator - (V p, V q) {
while(p.size() < q.size()) p.pb(0);
ff(i, 0, q.size()) p[i] -= q[i], p[i] < 0 ? p[i] += mo : 0;
return p;
}
V operator * (V p, int q) {
ff(i, 0, p.size()) p[i] = (ll) p[i] * q % mo;
return p;
}
V ni(V a) {
int n = a.size();
int tp = 0; while(1 << ++ tp < n);
V b; b.clear(); b.pb(ksm(a[0], mo - 2));
fo(i, 1, tp) {
V c = a; c.resize(1 << i);
c = c * b * b; c.resize(1 << i);
b = b * 2 - c;
}
b.resize(n); return b;
}
void fan(V &a) {
reverse(a.begin(), a.end());
}
V div(V a, V b) {
int n = a.size() - b.size() + 1;
fan(a); fan(b);
b.resize(a.size()); b = ni(b);
a = a * b; a.resize(n);
fan(a);
return a;
}
void mmo(V &a, V &b) {
if(a.size() < b.size()) return;
V c = div(a, b);
a = a - (c * b); a.resize(b.size() - 1);
}
const int M = 50005;
int n, k;
ll f[M], A[M];
V s, x, q;
int main() {
scanf("%d %d", &n, &k);
fo(i, 1, k) scanf("%lld", &f[i]);
fo(i, 0, k - 1) scanf("%lld", &A[i]);
fd(i, k, 1) q.pb(-f[i]); q.pb(1);
s.pb(1); x.resize(2); x[1] = 1;
for(; n; n /= 2) {
if(n & 1) {
s = s * x; mmo(s, q);
}
x = x * x; mmo(x, q);
}
ll ans = 0;
ff(i, 0, s.size())
ans += s[i] * A[i] % mo;
pp("%lld\n", (ans % mo + mo) % mo);
}