k次幂前缀和与拉格朗日插值

引入

先看一道例题:
给定 n,k n , k ,求:

i=1nik ∑ i = 1 n i k

对大质数取模。 n109 n ≤ 10 9
Part 1: k2 k ≤ 2
Part 2: k300 k ≤ 300
Part 3: k2000 k ≤ 2000
Part 4: k105 k ≤ 10 5
先看每一部分的解题方法。

Part 1

根据高中数学介绍的公式, O(1) O ( 1 ) 算出。

Part 2

我们已经不能手算了。由 k2 k ≤ 2 的情况,猜想答案可能是一个 k+1 k + 1 次多项式。于是可以高斯消元消出每一项系数, O(k3) O ( k 3 ) 得到答案。
具体证明可以参见差分与有限微积分——阮行止的博客。

Part 3

接下来这一部分,高斯消元已经无能为力了。接下来就要开始进入正题——拉格朗日插值法了。
什么?你已经推出来了?
f(k)=ni=1ik f ( k ) = ∑ i = 1 n i k ,我们有:

(n+1)k+1nk+1=i=0k(k+1i)ni ( n + 1 ) k + 1 − n k + 1 = ∑ i = 0 k ( k + 1 i ) n i

(n)k+1(n1)k+1=i=0k(k+1i)(n1)i ( n ) k + 1 − ( n − 1 ) k + 1 = ∑ i = 0 k ( k + 1 i ) ( n − 1 ) i


2k+11k+1=i=0k(k+1i)1i 2 k + 1 − 1 k + 1 = ∑ i = 0 k ( k + 1 i ) 1 i

将所有式子求和,得:
(n+1)k+11=i=0k(k+1i)f(i) ( n + 1 ) k + 1 − 1 = ∑ i = 0 k ( k + 1 i ) f ( i )

f(k)=(n+1)k+11k1i=0(k+1i)f(i)(k+1k) f ( k ) = ( n + 1 ) k + 1 − 1 − ∑ i = 0 k − 1 ( k + 1 i ) f ( i ) ( k + 1 k )

然后就可以 O(n2) O ( n 2 ) 算了。

Part 4

现在,上面那种奇♂怪的算法也不能做了,下面介绍拉格朗日插值法。

算法原理

所谓插值,就是将一些点值 (xi,yi) ( x i , y i ) 代入反解还原出多项式的过程,上文介绍的高斯消元解出多项式的系数,就是插值的过程。
拉格朗日插值法是一种高效的插值多项式的算法,以以法国数学家约瑟夫·拉格朗日命名。

先回到我们高斯消元解多项式的过程,我们对于一个 n n 次多项式(加上常数项共 n+1 n + 1 项),我们需要 n+1 n + 1 个点值来列方程,由线性代数的知识可得,一个含有 n+1 n + 1 个未知数的线性方程组有唯一解的必要条件是方程的个数 n+1 ≤ n + 1 ,所以我们需要 n+1 n + 1 个点值才能解出唯一对应的多项式。
在拉格朗日插值法中,我们也需要 n+1 n + 1 个点值。设点值为 (x0,y0),(x2,y2),,(xnyn) ( x 0 , y 0 ) , ( x 2 , y 2 ) , … , ( x n y n )

拉格朗日插值法的原理是构造一个拉格朗日基本多项式 lj(x) l j ( x ) ,满足 lj(xj)=1 l j ( x j ) = 1 , i=0n,ij,lj(xi)=0 ∀ i = 0 … n , i ≠ j , l j ( x i ) = 0
那么所得到的拉格朗日插值多项式为:

j=0nyjlj(x) ∑ j = 0 n y j l j ( x )

发现这个多项式对于所有的点值均成立。那么我们想办法构造出满足条件的 lj(x) l j ( x ) 。这里用到了一个十分暴力的方法:

lj(x)=i=0,ijnxxixjxi=(xx0)(xjx0)(xxj1)(xjxj1)(xxj+1)(xjxj+1)(xxn)(xjxn) l j ( x ) = ∏ i = 0 , i ≠ j n x − x i x j − x i = ( x − x 0 ) ( x j − x 0 ) … ( x − x j − 1 ) ( x j − x j − 1 ) ( x − x j + 1 ) ( x j − x j + 1 ) … ( x − x n ) ( x j − x n )

十分巧妙。

这样朴素计算显然是 O(n2) O ( n 2 ) 的。 但是我们可以取原函数在 0,1,,n 0 , 1 , … , n 处的取值,这样就把 lj(x) l j ( x ) 的分数线下化成阶乘的形式,分数线上化成下降阶乘幂的形式。于是就可以在 O(nlog2n) O ( n log 2 ⁡ n ) log2n log 2 ⁡ n 为快速幂)的复杂度来求解这个问题了。

例题

「BZOJ3453」「Tyvj1858」 XLkxc

题意:

给定 k,a,n,d,p k , a , n , d , p
f(i)=1k+2k+3k+......+ik f ( i ) = 1 k + 2 k + 3 k + . . . . . . + i k
g(x)=f(1)+f(2)+f(3)+....+f(x) g ( x ) = f ( 1 ) + f ( 2 ) + f ( 3 ) + . . . . + f ( x )
(g(a)+g(a+d)+g(a+2d)+......+g(a+nd)) modp ( g ( a ) + g ( a + d ) + g ( a + 2 d ) + . . . . . . + g ( a + n d ) )   mod p
1k123 1 ≤ k ≤ 123
0a,n,d123456789 0 ≤ a , n , d ≤ 123456789
p=1234567891 p = 1234567891

题解:

显然 f(i) f ( i ) 是一个 k+1 k + 1 次多项式。
显然 g(i) g ( i ) f(i) f ( i ) 的前缀和, 是一个 k+2 k + 2 次多项式。
显然 ans a n s g(i) g ( i ) 的前缀和, 是一个 k+3 k + 3 次多项式。

于是就可以大力插值,算出 g g ,然后算出 ans a n s
由于中间结果可能超过 p,预处理阶乘搞可能绘算出 0 0 ,于是我们可以暴力插值,同样能通过此题。
复杂度 O(k2) O ( k 2 )

My Code
/**************************************************************
    Problem: 3453
    User: infinityedge
    Language: C++
    Result: Accepted
    Time:520 ms
    Memory:1300 kb
****************************************************************/

#include 

using namespace std;
typedef long long ll;
const ll mod = 1234567891ll;



ll qpow(ll a, ll b){
    ll ret = 1;
    for(; b; b >>= 1, a = a * a % mod){
        if(b & 1) ret = ret * a % mod;
    }
    return ret;
}

ll dy[150];
ll fac[150], ifac[150];
void pre(ll k){
    for(int i = 1; i <= k + 3; i ++){
        dy[i] = dy[i - 1];
        for(int j = 1; j <= i; j ++){
            dy[i] = (dy[i] + qpow(j, k)) % mod;
        }
//      printf("%lld ", dy[i]);
    }
//      printf("\n");
    fac[0] = 1; ifac[0] = 1;
    for(int i = 1; i <= k + 4; i ++){
        fac[i] = fac[i - 1] * i % mod;
        ifac[i] = qpow(fac[i], mod - 2);
    }
}
ll facx[150], ifacx[150];
ll calG(ll k, ll x){
    if(x <= k + 3) return dy[x];
    facx[0] = x % mod; ifacx[0] = qpow(x % mod, mod - 2);
    for(int i = 1; i <= k + 3; i ++){
        facx[i] = facx[i - 1] * (x % mod - i + mod) % mod;
        ifacx[i] = qpow(facx[i], mod - 2);
    }
    ll ret = 0;
    for(int i = 1; i <= k + 3; i ++){
        ll tmp = 1;
        for(int j = 1; j <= k + 3; j ++){
            if(j == i) continue;
            tmp = tmp * (x % mod - j + mod) % mod;
        }
    //  tmp = tmp * facx[k + 3] * ifacx[i] % mod * facx[i - 1] % mod * ifacx[0] % mod;
        tmp = tmp * qpow(fac[i - 1] * fac[k + 3 - i] % mod, mod - 2) % mod;
        if((k + 3 - i) % 2 == 1) tmp = mod - tmp;
        ret = (ret + tmp * dy[i]) % mod;
    }
    return ret;
}
//1 3 6 10 15 21
//1 4 10 20 35 56
//1 5 15 35 70
ll y[150];
ll solve(ll k, ll a, ll n, ll d){
    for(int i = 0; i <= k + 3; i ++){
        y[i] = (y[i - 1] + calG(k, a + d * i)) % mod;
//      printf("%lld ", y[i]);
    }
//  printf("\n");
    if(n <= k + 3) return y[n];
    ll ret = 0;
    facx[0] = n % mod; ifacx[0] = qpow(n, mod - 2);
    for(int i = 1; i <= k + 3; i ++){
        facx[i] = facx[i - 1] * (n % mod - i + mod) % mod;
        ifacx[i] = qpow(facx[i], mod - 2);
//      printf("%lld ", facx[i]);
    }
//  printf("\n");
    for(int i = 0; i <= k + 3; i ++){
        ll tmp = 1;
        tmp = tmp * facx[k + 3] * ifacx[i] % mod;
        if(i != 0) tmp = tmp * facx[i - 1] % mod;
        if(i == 0) tmp = tmp * qpow(fac[k + 3] % mod, mod - 2) % mod;
        else tmp = tmp * qpow(fac[i] * fac[k + 3 - i] % mod, mod - 2) % mod;
        if((k + 3 - i) % 2 == 1) tmp = mod - tmp;
        ret = (ret + tmp * y[i]) % mod;
        //printf("%lld\n", y[i]);
    }
    return ret;
}


int main(){
    int T;
    scanf("%d", &T);
    while(T--){
        ll k, a, n, d;
        scanf("%lld%lld%lld%lld", &k, &a, &n, &d);
        pre(k);
        printf("%lld\n", solve(k, a, n, d));
    }
    return 0;
}

你可能感兴趣的:(学习笔记)