多项式快速插值学习小记

  • 今天终于抽空把这个综(du)合(liu)知识点学了,心力交瘁……

多项式快速插值

  • 给出 n n n 个点 ( x i , y i ) (x_i,y_i) (xi,yi) ,要求一个次数为 n − 1 n-1 n1 的多项式 F ( x ) F(x) F(x) 满足: F ( x i ) = y i F(x_i)=y_i F(xi)=yi

  • 显然这个多项式是唯一确定的。

  • 根据拉格朗日插值法,我们有: F ( x ) = ∑ i = 1 n ∏ j ̸ = i ( x − x j ) ∏ j ̸ = i ( x i − x j ) y i F(x)=\sum_{i=1}^{n}\frac{\prod_{j\not=i}(x-x_j)}{\prod_{j\not=i}(x_i-x_j)}y_i F(x)=i=1nj̸=i(xixj)j̸=i(xxj)yi

  • 这样我们是可以 O ( n 2 ) O(n^2) O(n2) 求的,考虑优化。

  • 我们先考虑对于每个 i i i ,如何快速得到 ∏ j ̸ = i ( x i − x j ) \prod_{j\not=i}(x_i-x_j) j̸=i(xixj)

  • M ( x ) = ∏ i = 1 n ( x − x i ) M(x)=\prod_{i=1}^{n}(x-x_i) M(x)=i=1n(xxi) ,我们即需求: M ( x ) x − x i \frac{M(x)}{x-x_i} xxiM(x)

  • 根据洛必达法则,当 x x x x i x_i xi 时,分子分母都等于0,一个0/0型,上下求导得: l i m x → x i M ( x ) x − x i = M ′ ( x ) lim_{x→x_i}\frac{M(x)}{x-x_i}=M'(x) limxxixxiM(x)=M(x)

  • 于是我们先分治NTT求出 M ( x ) M(x) M(x) ,再求导得到 M ′ ( x ) M'(x) M(x) ,之后将 x 1 − n x_{1-n} x1n 代入多点求值即可!

  • 这一部分的复杂度是 O ( n   l o g 2 n ) O(n\ log^2n) O(n log2n) 的。

  • 求导、多点求值这些前置知识我的博客里都有讲:

    多项式的求逆、取模和多点求值学习小记

    多项式的ln、exp和快速幂学习小记

  • 又设 V i = y i ∏ j ̸ = i ( x i − x j ) V_i=\frac{y_i}{\prod_{j\not=i}(x_i-x_j)} Vi=j̸=i(xixj)yi ,则此时 V i V_i Vi 已知,我们要求: F ( x ) = ∑ i = 1 n V i ∏ j ̸ = i ( x − x j ) F(x)=\sum_{i=1}^{n}V_i\prod_{j\not=i}(x-x_j) F(x)=i=1nVij̸=i(xxj)

  • 还是分治NTT,设 L ( x ) = ∑ i = 1 n / 2 ( x − x i ) L(x)=\sum_{i=1}^{n/2}(x-x_i) L(x)=i=1n/2(xxi) R ( x ) = ∑ i = n / 2 + 1 n ( x − x i ) R(x)=\sum_{i=n/2+1}^{n}(x-x_i) R(x)=i=n/2+1n(xxi) ,则有: F ( x ) = ∑ i = 1 n / 2 V i ∏ j ̸ = i , 1 ≤ j ≤ n / 2 ( x − x j ) R ( x ) + ∑ i = n / 2 + 1 n V i ∏ j ̸ = i , n / 2 + 1 ≤ j ≤ n ( x − x j ) L ( x ) F(x)=\sum_{i=1}^{n/2}V_i\prod_{j\not=i,1\leq j\leq n/2}(x-x_j)R(x)+\sum_{i=n/2+1}^{n}V_i\prod_{j\not=i,n/2+1\leq j\leq n}(x-x_j)L(x) F(x)=i=1n/2Vij̸=i,1jn/2(xxj)R(x)+i=n/2+1nVij̸=i,n/2+1jn(xxj)L(x)

  • 递归即可求得,递归底层的 F ( x ) F(x) F(x) 就是 V i V_i Vi

  • 还有就是这里的 L ( x ) 、 R ( x ) L(x)、R(x) L(x)R(x) 在多点求值中已经算过了,不用再算一遍啦。

  • 这一部分的复杂度也是 O ( n   l o g 2 n ) O(n\ log^2n) O(n log2n)

  • 故总时间复杂度即为 O ( n   l o g 2 n ) O(n\ log^2n) O(n log2n) ,常数很大很大。

  • 模板题:洛谷 P5158 【模板】多项式快速插值

Code

#include
#include
#include
using namespace std;
typedef long long LL;
const int N=1e5+5,M=18,G=3,mo=998244353;
int tot;
int xx[N],yy[N],val[N];
int a[N],b[N],c[N],rr[N];//a=b*c+rr
int ra[N],rb[N<<2],irb[N<<2];
int f[N*M<<1],stf[N<<2],enf[N<<2];
int g[N*M<<1],stg[N<<2],eng[N<<2];
int h[N],sth[N],enh[N],f3[N];
int f1[N<<2],f2[N<<2],wn[N<<2],rev[N<<2];
inline int read()
{
	int X=0,w=0; char ch=0;
	while(!isdigit(ch)) w|=ch=='-',ch=getchar();
	while(isdigit(ch)) X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
	return w?-X:X;
}
void write(int x)
{
	if(x>9) write(x/10);
	putchar(x%10+'0');
}
inline int ksm(int x,int y)
{
	int s=1;
	while(y)
	{
		if(y&1) s=(LL)s*x%mo;
		x=(LL)x*x%mo;
		y>>=1;
	}
	return s;
}
inline void NTT(int *y,int len,int ff)
{
	for(int i=0;i<len;i++)
		if(i<rev[i]) swap(y[i],y[rev[i]]);
	for(int h=2,d=len>>1;h<=len;h<<=1,d>>=1)
		for(int i=0,k=h>>1;i<len;i+=h)
			for(int j=0,cnt=0;j<k;j++,cnt+=d)
			{
				int u=y[i+j],t=(LL)wn[cnt]*y[i+j+k]%mo;
				y[i+j]=u+t>=mo?u+t-mo:u+t;
				y[i+j+k]=u-t<0?u-t+mo:u-t;
			}
	if(ff==-1)
	{
		reverse(y+1,y+len);
		int inv=ksm(len,mo-2);
		for(int i=0;i<len;i++) y[i]=(LL)y[i]*inv%mo;
	}
}
void make(int v,int l,int r)
{
	if(l==r)
	{
		g[stg[v]=++tot]=mo-xx[l];
		g[eng[v]=++tot]=1;
		return;
	}
	int mid=l+r>>1,ls=v<<1,rs=ls|1;
	make(ls,l,mid),make(rs,mid+1,r);
	int na=eng[ls]-stg[ls]+1,nb=eng[rs]-stg[rs]+1;
	int len=1,ll=0;
	while(len<na+nb) len<<=1,ll++;
	for(int i=0;i<len;i++) rev[i]=rev[i>>1]>>1|(i&1)<<ll-1;
	int w0=ksm(G,(mo-1)/len);
	for(int i=wn[0]=1;i<=len;i++) wn[i]=(LL)wn[i-1]*w0%mo;
	for(int i=0;i<na;i++) f1[i]=g[stg[ls]+i];
	for(int i=na;i<len;i++) f1[i]=0;
	for(int i=0;i<nb;i++) f2[i]=g[stg[rs]+i];
	for(int i=nb;i<len;i++) f2[i]=0;
	NTT(f1,len,1),NTT(f2,len,1);
	for(int i=0;i<len;i++) f1[i]=(LL)f1[i]*f2[i]%mo;
	NTT(f1,len,-1);
	stg[v]=tot+1;
	na+=nb-1;
	for(int i=0;i<na;i++) g[++tot]=f1[i];
	eng[v]=tot;
}
void getinv(int len,int ll)
{
	if(len==1)
	{
		irb[0]=ksm(rb[0],mo-2);
		return;
	}
	getinv(len>>1,ll-1);
	for(int i=0;i<len;i++) rev[i]=rev[i>>1]>>1|(i&1)<<ll-1;
	int w0=ksm(G,(mo-1)/len);
	for(int i=wn[0]=1;i<=len;i++) wn[i]=(LL)wn[i-1]*w0%mo;
	for(int i=0;i<len>>1;i++) f1[i]=rb[i];
	for(int i=len>>1;i<len;i++) f1[i]=0;
	NTT(f1,len,1),NTT(irb,len,1);
	for(int i=0;i<len;i++) irb[i]=(2-(LL)f1[i]*irb[i]%mo+mo)*irb[i]%mo;
	NTT(irb,len,-1);
	for(int i=len>>1;i<len;i++) irb[i]=0;
}
void solve(int v,int l,int r,int fa)
{
	int na=enf[fa]-stf[fa],nb=eng[v]-stg[v];
	if(na>=nb)
	{
		int nc=na-nb;
		for(int i=0;i<=na;i++) a[i]=f[stf[fa]+i];
		for(int i=0;i<=nb;i++) b[i]=g[stg[v]+i];
		for(int i=0;i<=nc;i++) ra[i]=a[na-i];
		for(int i=0;i<=nb;i++) rb[i]=b[nb-i];
		for(int i=nc+1;i<=nb;i++) rb[i]=0;
		int len=1,ll=0;
		while(len<=nc*2+1) len<<=1,ll++;
		for(int i=nb+1;i<len;i++) rb[i]=0;
		for(int i=0;i<len;i++) irb[i]=0,f1[i]=0;
		getinv(len,ll);
		for(int i=0;i<=nc;i++) f1[i]=ra[i],f2[i]=irb[i];
		for(int i=nc+1;i<len;i++) f1[i]=f2[i]=0;
		NTT(f1,len,1),NTT(f2,len,1);
		for(int i=0;i<len;i++) f1[i]=(LL)f1[i]*f2[i]%mo;
		NTT(f1,len,-1);
		for(int i=0;i<=nc;i++) c[nc-i]=f1[i];
		for(int i=nc+1;i<nb;i++) c[i]=0;
		len=1,ll=0;
		while(len<nb<<1) len<<=1,ll++;
		for(int i=0;i<len;i++) rev[i]=rev[i>>1]>>1|(i&1)<<ll-1;
		int w0=ksm(G,(mo-1)/len);
		for(int i=wn[0]=1;i<=len;i++) wn[i]=(LL)wn[i-1]*w0%mo;
		for(int i=0;i<nb;i++) f1[i]=b[i],f2[i]=c[i];
		for(int i=nb;i<len;i++) f1[i]=0,f2[i]=0;
		NTT(f1,len,1),NTT(f2,len,1);
		for(int i=0;i<len;i++) f1[i]=(LL)f1[i]*f2[i]%mo;
		NTT(f1,len,-1);
		for(int i=0;i<nb;i++) rr[i]=(a[i]-f1[i]+mo)%mo;
		while(nb>1 && !rr[nb-1]) nb--;
		stf[v]=tot+1;
		for(int i=0;i<nb;i++) f[++tot]=rr[i];
		enf[v]=tot;
	}else
	{
		stf[v]=tot+1;
		for(int i=stf[fa];i<=enf[fa];i++) f[++tot]=f[i];
		enf[v]=tot;
	}
	if(l==r)
	{
		val[l]=f[stf[v]];
		return;
	}
	int mid=l+r>>1;
	solve(v<<1,l,mid,v);
	solve(v<<1|1,mid+1,r,v);
}
void work(int v,int l,int r)
{
	if(l==r) return;
	int mid=l+r>>1,ls=v<<1,rs=ls|1;
	work(ls,l,mid);
	work(rs,mid+1,r);
	int na=enh[l]-sth[l]+1,nb=eng[rs]-stg[rs]+1;
	int len=1,ll=0;
	while(len<na+nb) len<<=1,ll++;
	for(int i=0;i<len;i++) rev[i]=rev[i>>1]>>1|(i&1)<<ll-1;
	int w0=ksm(G,(mo-1)/len);
	for(int i=wn[0]=1;i<=len;i++) wn[i]=(LL)wn[i-1]*w0%mo;
	for(int i=0;i<na;i++) f1[i]=h[sth[l]+i];
	for(int i=na;i<len;i++) f1[i]=0;
	for(int i=0;i<nb;i++) f2[i]=g[stg[rs]+i];
	for(int i=nb;i<len;i++) f2[i]=0;
	NTT(f1,len,1),NTT(f2,len,1);
	for(int i=0;i<len;i++) f1[i]=(LL)f1[i]*f2[i]%mo;
	NTT(f1,len,-1);
	na+=nb-1;
	for(int i=0;i<na;i++) f3[i]=f1[i];
	for(int i=na;i<len;i++) f3[i]=0;
	
	na=enh[mid+1]-sth[mid+1]+1,nb=eng[ls]-stg[ls]+1;
	for(int i=0;i<na;i++) f1[i]=h[sth[mid+1]+i];
	for(int i=na;i<len;i++) f1[i]=0;
	for(int i=0;i<nb;i++) f2[i]=g[stg[ls]+i];
	for(int i=nb;i<len;i++) f2[i]=0;
	NTT(f1,len,1),NTT(f2,len,1);
	for(int i=0;i<len;i++) f1[i]=(LL)f1[i]*f2[i]%mo;
	NTT(f1,len,-1);
	na+=nb-1;
	for(int i=0;i<na;i++) h[sth[l]+i]=f3[i]+f1[i]>=mo?f3[i]+f1[i]-mo:f3[i]+f1[i];
	enh[l]=sth[l]+na-1;
}
int main()
{
	int n=read();
	for(int i=1;i<=n;i++) xx[i]=read(),yy[i]=read();
	make(1,1,n);
	int m=eng[1]-stg[1];
	for(int i=0;i<=m;i++) f1[i]=g[stg[1]+i];
	for(int i=0;i<m;i++) f1[i]=(LL)f1[i+1]*(i+1)%mo;
	f1[m--]=0;
	stf[tot=0]=1;
	for(int i=0;i<=m;i++) f[enf[0]=++tot]=f1[i];
	solve(1,1,n,0);
	for(int i=1;i<=n;i++) val[i]=(LL)yy[i]*ksm(val[i],mo-2)%mo;
	tot=0;
	for(int i=1;i<=n;i++)
	{
		sth[i]=enh[i]=++tot;
		h[tot]=val[i];
	}
	work(1,1,n);
	for(int i=0;i<n;i++) write(h[sth[1]+i]),putchar(' ');
	return 0;
}

你可能感兴趣的:(NTT,模板与算法)