多项式快速幂(加强版)

建议阅读我的上一篇博客多项式快速幂

求多项式快速幂,但 a 0 ≠ 1 a_0\not=1 a0=1

由于求 ln ⁡ \ln ln 要求 a 0 = 1 a_0=1 a0=1,所以我们要想办法对多项式进行变换,使其满足 a 0 = 1 a_0=1 a0=1

如果 f ( x ) f(x) f(x) 常数项为 0 0 0,那么就整体除去 x x x 的若干次方,使常数项为 0 0 0

然后再对每项系数除以常数项,这样 a 0 a_0 a0 就等于 1 1 1 了。求法见我上一篇博客。

求出结果后,记得还原回去。

设原函数为 f ( x ) f(x) f(x),变换后的函数为 g ( x ) g(x) g(x),则 f ( x ) k = s k x t k g ( x ) k f(x)^k=s^kx^{tk}g(x)^k f(x)k=skxtkg(x)k s s s f ( x ) f(x) f(x) 从小到大第一个非零系数, t t t 是那一项的次数。

如果 t k ≥ n tk\ge n tkn,答案就是 0 0 0

s k s^k sk 可使用扩展欧拉定理, s k ≡ s k   m o d   φ ( p ) ( m o d p ) s^k\equiv s^{k\bmod \varphi(p)}\pmod p skskmodφ(p)(modp)

参考代码如下

#include
using namespace std;
typedef long long ll;
const int N=(1<<18)+1;
const ll mod=998244353,g=3,inv2=499122177;
int len=1,n;
ll a1[N],w,wn,a[N],ans[N],invans[N],lnans[N],da[N],inva[N],a2[N];
char s[N];
ll ksm(ll a,ll b)
{
    ll ans=1;
    while(b){
        if(b&1) ans=ans*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return ans;
}
void change(ll num[])
{
    for(int i=1,j=len/2;i<len-1;i++){
        if(i<j) swap(num[i],num[j]);
        int k=len/2;
        while(j>=k) j-=k,k>>=1;
        if(j<k) j+=k;
    }
}
void ntt(ll num[],int fl)
{
    for(int i=2;i<=len;i<<=1){
        if(fl==1) wn=ksm(g,(mod-1)/i);
        else wn=ksm(g,mod-1-(mod-1)/i);
        for(int j=0;j<len;j+=i){
            w=1;
            for(int k=j;k<j+i/2;k++){
                ll u=w*num[k+i/2]%mod,t=num[k];
                num[k]=(t+u)%mod;
                num[k+i/2]=(t-u+mod)%mod;
                w=w*wn%mod;
            }
        }
    }
    if(fl==-1){
        ll inv=ksm(len,mod-2);
        for(int i=0;i<len;i++) num[i]=num[i]*inv%mod;
    }
}
void getinv(int n,ll a[],ll ans[])
{
	if(n==1){ans[0]=ksm(a[0],mod-2);return;}
	getinv((n+1)/2,a,ans);
	len=1;
	while(len<2*n) len*=2;
	for(int i=0;i<n;i++) a1[i]=a[i];
	for(int i=n;i<len;i++) a1[i]=0;
	change(a1),change(ans);
	ntt(a1,1),ntt(ans,1);
	for(int i=0;i<len;i++) ans[i]=ans[i]*(2-ans[i]*a1[i]%mod+mod)%mod;
	change(ans),ntt(ans,-1);
	for(int i=n;i<len;i++) ans[i]=0;
}
void getln(int n,ll a[],ll ln[])
{
    memset(da,0,sizeof(da));
	for(int i=1;i<n;i++) da[i-1]=a[i]*i%mod;
    da[n-1]=0;
    memset(inva,0,sizeof(inva));
	getinv(n,a,inva);
	len=1;
	while(len<2*n) len*=2;
	change(da),change(inva);
	ntt(da,1),ntt(inva,1);
	for(int i=0;i<len;i++) ln[i]=da[i]*inva[i]%mod;
	change(ln),ntt(ln,-1);
	for(int i=len-1;i>=0;i--) ln[i+1]=ksm(i+1,mod-2)*ln[i]%mod;
    for(int i=n;i<len;i++) ln[i]=0;
	ln[0]=0;
}
void getsqrt(int n,ll a[],ll ans[])
{
    if(n==1){ans[0]=a[0];return;}
    getsqrt((n+1)/2,a,ans);
    len=1;
    while(len<2*n) len*=2;
    memset(invans,0,sizeof(invans));
    getinv(n,ans,invans);
    for(int i=0;i<n;i++) a1[i]=a[i];
    for(int i=n;i<len;i++) a1[i]=0;
    change(a1),change(invans);
    ntt(a1,1),ntt(invans,1);
    for(int i=0;i<len;i++) a1[i]=a1[i]*invans[i]%mod;
    change(a1),ntt(a1,-1);
    for(int i=0;i<n;i++) ans[i]=(a1[i]+ans[i])*inv2%mod;
    for(int i=n;i<len;i++) ans[i]=0;
}
void getexp(int n,ll a[],ll ans[])
{
    if(n==1){ans[0]=1;return;}
    getexp((n+1)/2,a,ans);
    len=1;
    while(len<2*n) len*=2;
    memset(lnans,0,sizeof(lnans));
    getln(n,ans,lnans);
    for(int i=0;i<n;i++) lnans[i]=(-lnans[i]+a[i]+mod)%mod;
    lnans[0]++;
    change(ans),change(lnans);
    ntt(ans,1),ntt(lnans,1);
    for(int i=0;i<len;i++) ans[i]=ans[i]*lnans[i]%mod;
    change(ans),ntt(ans,-1);
    for(int i=n;i<len;i++) ans[i]=0;
}
void getksm(int n,ll a[],char *s,ll ans[])
{
    ll k1=0,k2=0;
    int len=strlen(s);
    for(int i=0;i<len;i++) k1=(k1*10+s[i]-48)%mod,k2=(k2*10+s[i]-48)%(mod-1);
    ll x=0;
    while(x<n&&!a[x]) x++;
    if(x&&len>=6||k1*x>=n){
        memset(ans,0,sizeof(ans));
        return;
    }
    for(int i=0;i<n-x;i++) a[i]=a[i+x];
    for(int i=n-x;i<n;i++) a[i]=0;
    ll a0=a[0],y=ksm(a0,mod-2);
    for(int i=0;i<n-x*k1;i++) a[i]=a[i]*y%mod;
    getln(n-x*k1,a,a2);
    for(int i=0;i<n-x*k1;i++) a2[i]=a2[i]*k1%mod;
    getexp(n-x*k1,a2,ans);
    y=ksm(a0,k2);
    for(int i=n-1;i>=x*k1;i--) ans[i]=ans[i-x*k1]*y%mod;
    for(int i=0;i<x*k1;i++) ans[i]=0;
}
int main()
{
	scanf("%d%s",&n,s);
	for(int i=0;i<n;i++) scanf("%lld",&a[i]);
    getksm(n,a,s,ans);
    for(int i=0;i<n;i++) printf("%lld ",ans[i]);
}

你可能感兴趣的:(算法)