先看一道例题:
给定 n,k n , k ,求:
根据高中数学介绍的公式, O(1) O ( 1 ) 算出。
我们已经不能手算了。由 k≤2 k ≤ 2 的情况,猜想答案可能是一个 k+1 k + 1 次多项式。于是可以高斯消元消出每一项系数, O(k3) O ( k 3 ) 得到答案。
具体证明可以参见差分与有限微积分——阮行止的博客。
接下来这一部分,高斯消元已经无能为力了。接下来就要开始进入正题——拉格朗日插值法了。
什么?你已经推出来了?
设 f(k)=∑ni=1ik f ( k ) = ∑ i = 1 n i k ,我们有:
现在,上面那种奇♂怪的算法也不能做了,下面介绍拉格朗日插值法。
所谓插值,就是将一些点值 (xi,yi) ( x i , y i ) 代入反解还原出多项式的过程,上文介绍的高斯消元解出多项式的系数,就是插值的过程。
拉格朗日插值法是一种高效的插值多项式的算法,以以法国数学家约瑟夫·拉格朗日命名。
先回到我们高斯消元解多项式的过程,我们对于一个 n n 次多项式(加上常数项共 n+1 n + 1 项),我们需要 n+1 n + 1 个点值来列方程,由线性代数的知识可得,一个含有 n+1 n + 1 个未知数的线性方程组有唯一解的必要条件是方程的个数 ≤n+1 ≤ n + 1 ,所以我们需要 n+1 n + 1 个点值才能解出唯一对应的多项式。
在拉格朗日插值法中,我们也需要 n+1 n + 1 个点值。设点值为 (x0,y0),(x2,y2),…,(xnyn) ( x 0 , y 0 ) , ( x 2 , y 2 ) , … , ( x n y n )
拉格朗日插值法的原理是构造一个拉格朗日基本多项式 lj(x) l j ( x ) ,满足 lj(xj)=1 l j ( x j ) = 1 , ∀i=0…n,i≠j,lj(xi)=0 ∀ i = 0 … n , i ≠ j , l j ( x i ) = 0 。
那么所得到的拉格朗日插值多项式为:
十分巧妙。
这样朴素计算显然是 O(n2) O ( n 2 ) 的。 但是我们可以取原函数在 0,1,…,n 0 , 1 , … , n 处的取值,这样就把 lj(x) l j ( x ) 的分数线下化成阶乘的形式,分数线上化成下降阶乘幂的形式。于是就可以在 O(nlog2n) O ( n log 2 n ) ( log2n log 2 n 为快速幂)的复杂度来求解这个问题了。
给定 k,a,n,d,p k , a , n , d , p
f(i)=1k+2k+3k+......+ik f ( i ) = 1 k + 2 k + 3 k + . . . . . . + i k
g(x)=f(1)+f(2)+f(3)+....+f(x) g ( x ) = f ( 1 ) + f ( 2 ) + f ( 3 ) + . . . . + f ( x )
求 (g(a)+g(a+d)+g(a+2d)+......+g(a+nd)) modp ( g ( a ) + g ( a + d ) + g ( a + 2 d ) + . . . . . . + g ( a + n d ) ) mod p
1≤k≤123 1 ≤ k ≤ 123
0≤a,n,d≤123456789 0 ≤ a , n , d ≤ 123456789
p=1234567891 p = 1234567891
显然 f(i) f ( i ) 是一个 k+1 k + 1 次多项式。
显然 g(i) g ( i ) 是 f(i) f ( i ) 的前缀和, 是一个 k+2 k + 2 次多项式。
显然 ans a n s 是 g(i) g ( i ) 的前缀和, 是一个 k+3 k + 3 次多项式。
于是就可以大力插值,算出 g g ,然后算出 ans a n s 。
由于中间结果可能超过 p,预处理阶乘搞可能绘算出 0 0 ,于是我们可以暴力插值,同样能通过此题。
复杂度 O(k2) O ( k 2 ) 。
/**************************************************************
Problem: 3453
User: infinityedge
Language: C++
Result: Accepted
Time:520 ms
Memory:1300 kb
****************************************************************/
#include
using namespace std;
typedef long long ll;
const ll mod = 1234567891ll;
ll qpow(ll a, ll b){
ll ret = 1;
for(; b; b >>= 1, a = a * a % mod){
if(b & 1) ret = ret * a % mod;
}
return ret;
}
ll dy[150];
ll fac[150], ifac[150];
void pre(ll k){
for(int i = 1; i <= k + 3; i ++){
dy[i] = dy[i - 1];
for(int j = 1; j <= i; j ++){
dy[i] = (dy[i] + qpow(j, k)) % mod;
}
// printf("%lld ", dy[i]);
}
// printf("\n");
fac[0] = 1; ifac[0] = 1;
for(int i = 1; i <= k + 4; i ++){
fac[i] = fac[i - 1] * i % mod;
ifac[i] = qpow(fac[i], mod - 2);
}
}
ll facx[150], ifacx[150];
ll calG(ll k, ll x){
if(x <= k + 3) return dy[x];
facx[0] = x % mod; ifacx[0] = qpow(x % mod, mod - 2);
for(int i = 1; i <= k + 3; i ++){
facx[i] = facx[i - 1] * (x % mod - i + mod) % mod;
ifacx[i] = qpow(facx[i], mod - 2);
}
ll ret = 0;
for(int i = 1; i <= k + 3; i ++){
ll tmp = 1;
for(int j = 1; j <= k + 3; j ++){
if(j == i) continue;
tmp = tmp * (x % mod - j + mod) % mod;
}
// tmp = tmp * facx[k + 3] * ifacx[i] % mod * facx[i - 1] % mod * ifacx[0] % mod;
tmp = tmp * qpow(fac[i - 1] * fac[k + 3 - i] % mod, mod - 2) % mod;
if((k + 3 - i) % 2 == 1) tmp = mod - tmp;
ret = (ret + tmp * dy[i]) % mod;
}
return ret;
}
//1 3 6 10 15 21
//1 4 10 20 35 56
//1 5 15 35 70
ll y[150];
ll solve(ll k, ll a, ll n, ll d){
for(int i = 0; i <= k + 3; i ++){
y[i] = (y[i - 1] + calG(k, a + d * i)) % mod;
// printf("%lld ", y[i]);
}
// printf("\n");
if(n <= k + 3) return y[n];
ll ret = 0;
facx[0] = n % mod; ifacx[0] = qpow(n, mod - 2);
for(int i = 1; i <= k + 3; i ++){
facx[i] = facx[i - 1] * (n % mod - i + mod) % mod;
ifacx[i] = qpow(facx[i], mod - 2);
// printf("%lld ", facx[i]);
}
// printf("\n");
for(int i = 0; i <= k + 3; i ++){
ll tmp = 1;
tmp = tmp * facx[k + 3] * ifacx[i] % mod;
if(i != 0) tmp = tmp * facx[i - 1] % mod;
if(i == 0) tmp = tmp * qpow(fac[k + 3] % mod, mod - 2) % mod;
else tmp = tmp * qpow(fac[i] * fac[k + 3 - i] % mod, mod - 2) % mod;
if((k + 3 - i) % 2 == 1) tmp = mod - tmp;
ret = (ret + tmp * y[i]) % mod;
//printf("%lld\n", y[i]);
}
return ret;
}
int main(){
int T;
scanf("%d", &T);
while(T--){
ll k, a, n, d;
scanf("%lld%lld%lld%lld", &k, &a, &n, &d);
pre(k);
printf("%lld\n", solve(k, a, n, d));
}
return 0;
}