多项式取模优化齐次线性递推

线性递推

给出长为 m m 的数列 a=a1,a2,,am a = ⟨ a 1 , a 2 , … , a m ⟩ ,以及无穷数列 f f 的前 m m f0,f1,,fm1 ⟨ f 0 , f 1 , … , f m − 1 ⟩ ,对于 im ∀ i ≥ m ,有 fi=mj=1ajfij f i = ∑ j = 1 m a j f i − j 。求 fn f n
n1018,m30000 n ≤ 10 18 , m ≤ 30000

不同复杂度的做法

  • 暴力 O(mn) O ( m n )
  • 矩阵快速幂 O(m3logn) O ( m 3 log ⁡ n )
  • 暴力多项式快速幂 O(m2logn) O ( m 2 log ⁡ n )
  • nlog n log 的多项式取模写快速幂 O(mlogmlogn) O ( m log ⁡ m log ⁡ n )

前两种暴力就不讲了,这里主要讲一下对“多项式取模为什么可以解线性递推”的理解。

线性递推与多项式取模

这篇博客写得不错,可以参考一下。

这里讲一种不用矩阵来理解的方法。我们直接从计算数列 f f 的方法开始

假设我们对 fn f n 搞出了一个无穷数列 b=b0,b1, b = ⟨ b 0 , b 1 , … ⟩ ,表示 fn f n 可以由这些项求得: fn=ibifi f n = ∑ i b i f i 。当然这样的 b b 有很多,我们很容易搞出一个平凡的:只有第 n n 项为 1 1 ,其余项均为 0 0

现在我们希望把 b b 变短一些,因为如果能把 b b 变换成只有前 m m 项非 0 0 ,我们就可以立即求出 fn f n 了。

b b 的最高非 0 0 项为 bk b k ,那么只要 km k ≥ m bk b k 就可以根据递推式疯狂甩锅,把系数分给前面的项。即 kmi<k,bi+=bkaki ∀ k − m ≤ i < k , b i += b k ∗ a k − i

其实就是把一个 m+1 m + 1 项的常数列 am,am1,,a1,1 ⟨ − a m , − a m − 1 , … , − a 1 , 1 ⟩ 向右平移,并从 b b 中减去它的 bk b k 倍,从而消去了 bk b k

这个过程与暴力多项式取模过程不谋而合。事实上,初中数学常见题型:给一个低次多项式,对一个高次多项式化简求值——用到的其实就是多项式取模,与这里我们对数组 b b 取模的操作完全一样。

所以,把 b b 看做多项式 B=ibixi B = ∑ i b i x i 的系数,设多项式 A=xmm1i=0amixi A = x m − ∑ i = 0 m − 1 a m − i x i 。初始有 B=xn B = x n ,我们只要求出 BmodA B m o d A 。所以我们用 B1=x1 B 1 = x 1 在模 A A 意义下快速幂一下就做完了。就这么简单!

事实上,直接用多项式可以一句话解释清楚为什么多项式取模可以解线性递推:

定义一个多项式 P=ipixi P = ∑ i p i x i 的值为 ipifi ∑ i p i f i 。那么 A A 其实是一个零值恒等式: fmm1i=0amifi=0 f m − ∑ i = 0 m − 1 a m − i f i = 0
Axk A x k 表示把这个恒等式中 f f 的下标向右平移 k k 位,当然还是成立的。
B-=A B -= A 就是系数对位相加减,由于 A A 是零值,当然不改变 B B 的值。
所以,在 B B 中加减 A A 的任意整式倍都不改变值,即 B1=B2(modA)value(B1)=value(B2) B 1 = B 2 ( m o d A ) ⇒ v a l u e ( B 1 ) = v a l u e ( B 2 ) 。因此我们可以在模 A A 意义下求 xn x n 的值。

多项式取模卡常小技巧

多项式取模的具体写法就不在这里讲了,这里假设你已经会了多项式 O(nlogn) O ( n log ⁡ n ) 那一套操作。推荐SemiWaker这篇博客,有详细的讲解和模板代码。另外Miskcoo的数学和多项式都写得非常好,值得一看。

一个常见优化是在多项式求逆这样的倍增时。设 m=n/2 m = ⌈ n / 2 ⌉ ,我们只要知道 [m,n) [ m , n ) 这些项的值,所以循环卷积溢出到 [0,m) [ 0 , m ) 这些项是无所谓的。所以FFT的长度只需要开 1.5n ⌈ 1.5 n ⌉ ,而不是 2n 2 n

最重要的优化是,考虑到我们每次模数都是一样的,而多项式取模的主要常数在于对模数的翻转求逆,这就可以预处理。本来多项式取模常数约是FFT的12倍,这样可以优化到每次取模只做4次DFT,优化了3倍。另外,日常快速幂的多项式乘法也可以优化掉一半的DFT。

第3个小优化是,我们不必从 x1 x 1 开始快速幂,可以从 xm1 x m − 1 开始。考虑到 logmlogn>14 log ⁡ m log ⁡ n > 1 4 ,这样可以变快 20% 20 % 以上。

第4个小优化是,根据定义,我们可以 O(m) O ( m ) 暴力递推一项,从而使 xk x k 变成 xk+1 x k + 1 。所以 x2k+1=xkxkx x 2 k + 1 = x k ⋅ x k ⋅ x 只要做1次模意义下乘法,而不是两次。所以写成递归式的快速幂比迭代式的可以快 30% 30 %

把这些优化加上以后,开头所说的数据范围 n1018,m30000 n ≤ 10 18 , m ≤ 30000 可以在1s内跑出来。

代码实现

#include 
#include 
#include 
typedef long long ll;
const ll MOD = 998244353;
const int PHI = MOD-1, MAX = 1<<17;

inline int get(int n) {
    int N = 1;
    while (N < n) N <<= 1;
    return N;
}

inline ll pow(ll x, int k) {
    ll a = 1;
    do{ if (k & 1) a = a * x %MOD;
        x = x * x %MOD;
    } while (k >>= 1);
    return a;
}
inline ll inv(ll a) {return pow(a, PHI-1);}

ll w[24], _w[24];
void init() {
    w[22] = pow(3, PHI>>23);
    _w[22] = pow((MOD+1)/3, PHI>>23);
    for (int i = 22; i; i--) {
        w[i-1] = w[i] * w[i] % MOD;
        _w[i-1] = _w[i] * _w[i] % MOD;
    }
}

#define DFT(A,N) FFT(N,w)
#define IDFT(A,N) FFT(N,_w)
template bool flag>
void FFT(int N, const ll *w0) {
    for (int i = 1, j = N>>1; i < N; i++) {
        if (i < j) std::swap(A[i], A[j]);
        for (int k = N>>1; (j ^= k) < k; k >>= 1);
    }
    for (int d = 1; d < N; d <<= 1, w0++)
        for (int l = 0; l < N; l += d<<1) {
            ll w = 1;
            for (int i = l; i < l+d; i++) {
                ll tmp = A[i+d] * w;
                A[i+d] = (A[i] - tmp) %MOD;
                A[i] = (A[i] + tmp) %MOD;
                w = w * *w0 %MOD;
            }
        }
    if (flag) {
        ll t = inv(N);
        for (int i = 0; i < N; i++)
            (A[i] *= t) %= MOD;
    }
}

#define clear(A,n) memset(A,0,(n)<<3)
template <const ll A[], ll a[]>
inline void getDFT(int n, int N) {
    std::copy(A, A+n, a);
    clear(a+n, N-n);
    DFT(a, N);
}
inline void mul(ll *A, const ll *B, int N) {
    while (N--) (*A++ *= *B++) %= MOD;
}

ll ta[MAX], tb[MAX], tc[MAX], td[MAX];
template <const ll A[], ll B[]>
void Inv(int n) {
    if (n < 56) {
        *B = inv(*A);
        for (int i = 1; i < n; i++) {
            ll tot = 0;
            for (int j = 0; j < i; j++)
                tot = (tot + B[j] * A[i-j]) %MOD;
            B[i] = -tot * *B %MOD;
        }
        return;
    }
    int m = (n+1)>>1, N = get(n+m-1);
    Inv(m);
    getDFT(n,N);
    getDFT(m,N);
    for (int i = 0; i < N; i++)
        (tb[i] *= 2 - ta[i] * tb[i] %MOD) %= MOD;
    IDFT(tb, N);
    std::copy(tb+m, tb+n, B+m);
}

int m, n, d, N; //(2m-1) mod (m+1) = div (m-1) rem m
template <const ll B[]> void setMod() {
    d = m-1; n = m+d; N = get(n);
    std::reverse_copy(B+1, B+m+1, tc);
    Inv(d);
    clear(td+d, N-d);
    DFT(td, N);
    getDFT(m+1,N);
}

template  void Mod() {
    std::reverse_copy(A+n-d, A+n, tb);
    clear(tb+d, N-d);
    DFT(tb, N);
    mul(tb, td, N);
    IDFT(tb, N);
    std::reverse(tb, tb+d);
    clear(tb+d, N-d);
    DFT(tb, N);
    mul(tb, ta, N);
    IDFT(tb, N);
    for (int i = 0; i < n; i++)
        A[i] -= tb[i];
}

int f[MAX];
ll A[MAX], B[MAX];

void Solve(ll n) {
    if (n < m) {B[n] = 1; return;}
    Solve(n >> 1);
    DFT(B, N);
    mul(B, B, N);
    IDFT(B, N);
    Mod();
    if (n & 1) {
        ll tmp = B[m-1];
        for (int i = m; --i;)
            B[i] = (B[i-1] + tmp * A[i]) %MOD;
        B[0] = tmp * A[0] %MOD;
    }
}

int main() {
    init();
    ll n, ans = 0;
    scanf("%d%lld",&m,&n);
    for (int i = m; i--;)
        scanf("%lld",A+i);
    for (int i = 0; i < m; i++)
        scanf("%d",f+i);
    if (n < m) {printf("%d\n",f[n]); return 0;}
    if (!m) {puts("0"); return 0;}
    if (m == 1) {printf("%lld\n",*f*pow(*A,n%PHI)%MOD); return 0;}
    A[m] = -1;
    setMod();
    Solve(n);
    for (int i = 0; i < m; i++)
        ans = (ans + B[i] * f[i]) %MOD;
    if (ans < 0) ans += MOD;
    printf("%lld\n",ans);
    return 0;
}

你可能感兴趣的:(算法学习笔记,多项式)