BZOJ4161: Shlw loves Matrixl 题解

如果把转移写成矩阵的形式做矩阵乘法,复杂度是 O(k3logn) O ( k 3 l o g n ) 的,不足以通过此题
学习了一下这个noi2017用到的技巧:可以用矩阵的特征多项式优化常系数齐次线性递推


前置技能

特征多项式

矩阵 A A 的特征多项式 P(λ)=detλIA P ( λ ) = d e t ∣ λ I − A ∣ ,它满足对于 矩阵 A A 的特征值 λi λ i ,都有 P(λi)=0 P ( λ i ) = 0 (参考特征值的性质)
n阶的常数项线性递推的转移矩阵 A A 可以用以下公式计算特征多项式:

P(λ)=(1)n(λni=1naiλni) P ( λ ) = ( − 1 ) n ( λ n − ∑ i = 1 n a i λ n − i )

证明甩锅,见博客 https://www.cnblogs.com/Troywar/p/9078013.html

Cayley-Hamilton定理

对于矩阵 A A 的特征多项式 P(λ) P ( λ ) P(A)=0 P ( A ) = 0
(这玩意我又不会证,但不是把A代入 λ λ 然后消一消这么简单的)


类比

xn=f(x)g(x)+h(x) x n = f ( x ) g ( x ) + h ( x )

我们把矩阵 A A 视为像 x x 一样的变元,可以写出
An=f(A)g(A)+h(A) A n = f ( A ) g ( A ) + h ( A )

如果我们取 f(A)=P(λ)=0 f ( A ) = P ( λ ) = 0 ,则 Anh(A)(modg(λ)) A n ≡ h ( A ) ( m o d g ( λ ) )
我们知道因为 P(λ) P ( λ ) 是k次多项式,所以 h(A) h ( A ) 是不超过k-1次的多项式, An A n 实际上就是只有n次的那个系数是1其他都是0的一个n次多项式,如果我们用快速幂+暴力多项式取模的话,求出 h(A) h ( A ) 的复杂度是 O(k2logk) O ( k 2 l o g k ) ,如果用FFT加速,复杂度是 O(klogklogn) O ( k l o g k l o g n )
考虑求出 h(A) h ( A ) 以后怎么求出 f(n) f ( n )
考虑
BTi=(f(i+k1),f(i+k2)...f(i+1),f(i)) B i T = ( f ( i + k − 1 ) , f ( i + k − 2 ) . . . f ( i + 1 ) , f ( i ) )

那么
AnB0=Bn=(f(n+k1),f(n+k2)...f(n+1,f(n)) A n B 0 = B n = ( f ( n + k − 1 ) , f ( n + k − 2 ) . . . f ( n + 1 , f ( n ) )

我们之前已经求出了
h(A)=i=0k1hiAi h ( A ) = ∑ i = 0 k − 1 h i A i

的每个 hi h i ,那么
AnB0=h(A)B0=(i=0k1hiAi)B0 A n B 0 = h ( A ) B 0 = ( ∑ i = 0 k − 1 h i A i ) B 0

h(A) h ( A ) 的每一项拆开来和 B0 B 0 乘,会发现
((i=0k1hiAi)B0)T=i=0k1BT0(hiAn)=i=0k1hi(f(i+k2),f(i+k2)...f(i+1),f(i))=(i=0k1hif(i+k1),i=0k1hif(i+k2)i=0k1hif(i))=BTn ( ( ∑ i = 0 k − 1 h i A i ) B 0 ) T = ∑ i = 0 k − 1 B 0 T ( h i A n ) = ∑ i = 0 k − 1 h i ( f ( i + k − 2 ) , f ( i + k − 2 ) . . . f ( i + 1 ) , f ( i ) ) = ( ∑ i = 0 k − 1 h i f ( i + k − 1 ) , ∑ i = 0 k − 1 h i f ( i + k − 2 ) … ∑ i = 0 k − 1 h i f ( i ) ) = B n T

我们要的是 BTn B n T 的最后一项,所以
ans=i=0k1hif(i) a n s = ∑ i = 0 k − 1 h i f ( i )

#include 
using namespace std;

#define LL long long
#define LB long double
#define ull unsigned long long
#define x first
#define y second
#define pb push_back
#define pf push_front
#define mp make_pair
#define Pair pair
#define pLL pair
#define pii pair
#define LOWBIT(x) x & (-x)

const int INF=2e9;
const LL LINF=2e16;
const int magic=348;
const int MOD=1e9+7;
const double eps=1e-10;
const double pi=acos(-1);

inline int getint()
{
    bool f;char ch;int res;
    while (!isdigit(ch=getchar()) && ch!='-') {}
    if (ch=='-') f=false,res=0; else f=true,res=ch-'0';
    while (isdigit(ch=getchar())) res=res*10+ch-'0';
    return f?res:-res;
}

const int MAXN=2000;

int n,k;
int a[MAXN+48],h[MAXN+48];
int g[MAXN+48],inv[MAXN+48];

inline int add(int x) {if (x>=MOD) x-=MOD;return x;}
inline int sub(int x) {if (x<0) x+=MOD;return x;}

inline int quick_pow(int x,int y)
{
    int res=1;
    while (y)
    {
        if (y&1) res=(1ll*res*x)%MOD,y--;
        x=(1ll*x*x)%MOD;y>>=1;
    }
    return res;
}

inline void construct_g()
{
    int i;
    g[k]=1;for (i=k;i>=1;i--) g[k-i]=sub(-a[i]);
    if (k&1) for (i=k;i>=0;i--) g[i]=sub(-g[i]);
    for (i=0;i<=k;i++) inv[i]=quick_pow(g[i],MOD-2);
}

struct poly
{
    int a[MAXN*5+48],len;
    inline void clear() {for (register int i=0;i0;}
    inline poly operator * (poly other)
    {
        poly res;res.len=len+other.len-1;res.clear();
        for (register int i=0;ifor (register int j=0;j1ll*a[i]*other.a[j])%MOD);
        for (register int i=res.len-1;i>=k;i--)
        {
            int tmp=(1ll*inv[k]*res.a[i])%MOD;
            for (register int j=i;j>=i-k;j--)
                res.a[j]=sub(res.a[j]-(1ll*g[j-i+k]*tmp)%MOD);
        }
        if (res.len>k) res.len=k;
        return res;
    }
    inline void print() {for (register int i=0;icout<' ';cout<inline poly quick_pow(poly x,int y)
{
    poly res;res.len=1;res.a[0]=1;
    while (y)
    {
        if (y&1) res=res*x,y--;
        x=x*x;y>>=1;
    }
    return res;
}

int main ()
{
    //freopen ("a.in","r",stdin);
    //freopen ("a.out","w",stdout);
    int i;n=getint();k=getint();
    for (i=1;i<=k;i++) a[i]=getint();
    for (i=0;i<=k-1;i++) h[i]=getint(),h[i]=add(h[i]+MOD);
    if (n<=k-1) {printf("%d\n",add(h[n]+MOD));return 0;}
    construct_g();poly base;base.len=2;base.a[1]=1;base.a[0]=0;
    poly res=quick_pow(base,n);
    int ans=0;
    for (i=0;i<=k-1;i++) ans=add(ans+(1ll*h[i]*res.a[i])%MOD);
    printf("%d\n",ans);
    return 0;
}

你可能感兴趣的:(矩阵,矩阵特征多项式)