常系数齐次线性递推(知识总结+板子整理)

心得

其实挑战上提了一句这个,只滚系数是O(k²logn)的,但是一直不会实现

今天终于把代码啃下来了,以后带着板子抄就可以了……

适用于dp递推式,O(k³logn)矩阵快速幂超时的场合,

可O(k²)的BM求线性递推式或O(k²logn)的系数矩阵快速幂

思路来源

https://ac.nowcoder.com/acm/contest/view-submission?submissionId=40893084杜教AC代码

https://wenku.baidu.com/view/bac23be1c8d376eeafaa3111.html(叉姐的论文《线性递推关系与矩阵乘法》)

知识整理

如果a_{k}=c_{0}a_{0}+c_{1}a_{1}+...+c_{k-1}a_{k-1}

那么a_{n}=c_{0}a_{n-k}+c_{1}a_{n-k+1}+...+c_{k-1}a_{n-1},倒着往回推系数,

迟早能推成全是a_{0},a_{1},...,a_{k-1}表示的项,再用对应系数一乘一求和就得到一个an的答案

所以思路就是,若要将an用a_{0},a_{1},...,a_{k-1}表示,

就要分治的将a_{\frac{n}{2}}a_{0},a_{1},...,a_{k-1}表示,a_{n-\frac{n}{2}}a_{0},a_{1},...,a_{k-1}表示,

然后a_{\frac{n}{2}}a_{n-\frac{n}{2}}对应系数一乘乘出a_{0},a_{1},...,a_{2k-2}的系数,

再把a_{k},a_{k+1},...,a_{2k-2}这些项每一个用a_{n}=c_{0}a_{n-k}+c_{1}a_{n-k+1}+...+c_{k-1}a_{n-1}这样的式子,

倒着从a_{2k-2}下放系数,直到所有的系数都是用a_{0},a_{1},...,a_{k-1}表示的为止

由于,第一次构造的是x^{1}的k阶表示,后续不断自乘才得到最高的k阶表示

所以,代码中w和x的作用就是交换高低位,w是不大于n的最大的2的次幂

而当w当前最高位为1时,x当前最低位为1,b是判断现在n中是否有w的这一位

如果有,执行形如x^{5}=(x^{2})^{2}*x^{1},实际操作时把v[i]*v[j]的结果直接加到u[i+j+1]上,

其效果等同于先加到u[i+j]上然后向高一位乘v[1],因为这里v[1]==1恒成立,就省略了

板子整理

以2019牛客暑期多校训练营(第二场)B题.Eddy Walker 2为例

k<=1050,n<=1e18,只能用O(k²)的BM或O(k²logn)的系数矩阵快速幂

#include 
using namespace std;
#define rep(i,n) for(int i=1;i<=n;++i)
#define mp make_pair
#define pb push_back
#define x0 gtmsub
#define y0 gtmshb
#define x1 gtmjtjl
#define y1 gtmsf
typedef long long ll;
//M为递推项系数个数 c为系数数组
//最终递推式为a[m]=c[0]a[0]+c[1]a[1]+...+c[m-1]a[m-1]  
const int M=1050,P=1000000007;
//求快速幂 只写一个系数即求逆元 
ll pw(ll x,ll y=P-2){
    ll s=1;
    for(;y;y>>=1,x=1ll*x*x%P)
        if(y&1)s=1ll*s*x%P;
    return s;
}
ll i,w,x,b,j,t,a[M],c[M],v[M],u[M<<1],ans;

//推a1...ak每项最终系数的矩阵快速幂 复杂度O(k^2 logn) 
//求a^n的k个a^1 a^k的系数 即求a^(n/2)的k个系数 然后乘在一起变成a^1 到a^(2k)的系数
//然后暴力把a^(k+1)到a^(2k)的系数再倒序下放到a^1 到 a^k上 
ll sol(ll n,ll m) {//求a[n] a[n]来自前m项递推式 
    //scanf("%d%d",&n,&m);
    n+=m-1;//整体右移m-1项 便于0-(m-1)向负的找系数 应为0 
    for(i=m-1;~i;i--)c[i]=pw(m);//c[m-1]到c[0] 每个1/m 相当于x^1  
    for(i=0;i1;i>>=1)w<<=1;//n=0时w=0 n!=0时w>1 w为不大于n的最大2的次幂 
    for(x=0;w;copy(u,u+m,v),w>>=1,x<<=1){//copy把[u,u+m]复制给v 
        fill_n(u,m<<1,0),b=!!(n&w),x|=b;//fill_n 把u.begin()的连续m<<1个位置 都覆盖成0 
        //如果n&w==0 b=0;n&w==w b=1 用两个!把非空判成了1 
        //如果w最高位为1 则x最低位|=1 
        if (x=m;i--)for(j=0,t=i-m;j

 

你可能感兴趣的:(知识点总结)