[loj6391][THUPC2018]淘米神的树(Tommy)

前言

经典板子应用题

题目相关

链接

题目大意

现在有一个 n n n个节点的树,初始 n n n个节点有 n − 2 n-2 n2个是白色的, 2 2 2个是黑色的
每次可以将一个黑点染红,并将相邻的白点同时染黑
问把整棵树染红的方法数,答案模 998244353 998244353 998244353

数据范围

n ≤ 234567 n\le234567 n234567

题解

两个点的情况比较麻烦
我们先考虑一开始只有一个黑点的情况
我们把一开始的黑点当作根节点建树
对于每一个节点 u u u,设 u u u节点及其子树的大小为 s i z e u size_u sizeu,我们分析其子树对答案的贡献
其子树肯定要选 u u u节点本身
剩下的节点如果没有限制,那么方案数为 ( s i z e u − 1 ) ! (size_u-1)! (sizeu1)!
然后我们考虑限制,我们发现每个子树内部有其限制
那么总方案数为
a n s u = ( s i z e u − 1 ) ! ∏ f a v = u a n s v s i z e v ! ans_u=(size_u-1)!\prod_{fa_v=u}\frac{ans_v}{size_v!} ansu=(sizeu1)!fav=usizev!ansv
进行整理
a n s u = ( s i z e u − 1 ) ! ∏ f a v = u s i z e v ! ∏ f a v = u a n s v ans_u=\frac{(size_u-1)!}{\prod_{fa_v=u}size_v!}\prod_{fa_v=u}{ans_v} ansu=fav=usizev!(sizeu1)!fav=uansv
然后容易发现
a n s r o o t = ∏ u = 1 n ( s i z e u − 1 ) ! ∏ f a v = u s i z e v ! = ∏ u = 1 n ( s i z e u − 1 ) ! ∏ u = 1 , u ≠ r o o t n s i z e u ! = ( s i z e r o o t − 1 ) ! ∏ u = 1 , u ≠ r o o t n ( s i z e u − 1 ) ! s i z e u ! = ( s i z e r o o t − 1 ) ! ∏ u = 1 , u ≠ r o o t n 1 s i z e u = s i z e r o o t ! ∏ u = 1 n 1 s i z e u = n ! ∏ u = 1 n s i z e u \begin{aligned} ans_{root}&=\prod_{u=1}^n\frac{(size_u-1)!}{\prod_{fa_v=u}size_v!}\\ &=\frac{\prod_{u=1}^n(size_u-1)!}{\prod_{u=1,u\ne root}^nsize_u!}\\ &=(size_{root}-1)!\prod_{u=1,u\ne root}^n\frac{(size_u-1)!}{size_u!}\\ &=(size_{root}-1)!\prod_{u=1,u\ne root}^n\frac{1}{size_u}\\ &=size_{root}!\prod_{u=1}^n\frac{1}{size_u}\\ &=\frac{n!}{\prod_{u=1}^nsize_u}\\ \end{aligned} ansroot=u=1nfav=usizev!(sizeu1)!=u=1,u̸=rootnsizeu!u=1n(sizeu1)!=(sizeroot1)!u=1,u̸=rootnsizeu!(sizeu1)!=(sizeroot1)!u=1,u̸=rootnsizeu1=sizeroot!u=1nsizeu1=u=1nsizeun!
然后考虑两个点的情况
设初始的两个黑点为 a , b a,b a,b,我们新加一个点 s s s,将 s s s连向 a a a b b b,并且初始设为只有 s s s为黑点,那么问题就转化成环套树上求答案
我们的思路是枚举环上染红的最后一个点的位置,但这样并不能直接做
所以我们枚举环上的边并将其删除使图成为树并计算答案
我们发现,这么计算会出现重复的情况,对于一个方案,我们找到这个方案在环上的最后被染红的点,我们发现其被环上相邻两点染黑的情况各一次,即所有方案被算到两次,所以答案要乘以 1 2 \frac12 21
对于非环上的点我们可以进行预处理,对于节点 s s s也可以直接预处理,设这些点的 s i z e size size和为 Z Z Z
对于一个环上的点 u u u,设 a u a_u au为其所有的确定子树大小,即一加所有非环上儿子的 s i z e size size,设 b u b_u bu为对于某种断环方式下其子树的 s i z e size size
t t t为某种断环方式下的环上子树大小和即 t = ∑ b u t=\sum b_u t=bu
我们观察一下 t t t的取值情况
我们假设环上的点有 k + 1 k+1 k+1个,设 s s s点编号为0,其余的点依次沿环编号为1 ~ k
假设断的边是i和i+1
我们发现其实 b b b就是1 ~ i里的 a a a的后缀和,i+1 ~ k里的 a a a的前缀和
我们对 a a a做一遍前缀和成为 c c c,我们发现 ∏ b = ∏ j ≠ i ∣ c i − c j ∣ \prod b=\prod_{j\ne i}|c_i-c_j| b=j̸=icicj(注意 c 0 = 0 c_0=0 c0=0
将绝对值去掉(直接分类讨论即可)后,问题就和多项式快速插值里的一样了
摘一下结论:
R ( x ) = ∏ j = 1 n ( x − x j ) R(x)=\prod_{j=1}^n(x-x_j) R(x)=j=1n(xxj)
∏ j = 1 , i ≠ j n ( x i − x j ) = R ′ ( x i ) \prod_{j=1,i\neq j}^n(x_i-x_j)=R'(x_i) j=1,i̸=jn(xixj)=R(xi)
这样的话多项式多点求值即可
复杂度 O ( n l o g 2 n ) \mathcal O(nlog^2n) O(nlog2n)

代码

本代码使用了板子

#include
#include
#include
#include
#include
namespace fast_IO
{
    const int IN_LEN=10000000,OUT_LEN=10000000;
    char ibuf[IN_LEN],obuf[OUT_LEN],*ih=ibuf+IN_LEN,*oh=obuf,*lastin=ibuf+IN_LEN,*lastout=obuf+OUT_LEN-1;
    inline char getchar_(){return (ih==lastin)&&(lastin=(ih=ibuf)+fread(ibuf,1,IN_LEN,stdin),ih==lastin)?EOF:*ih++;}
    inline void putchar_(const char x){if(oh==lastout)fwrite(obuf,1,oh-obuf,stdout),oh=obuf;*oh++=x;}
    inline void flush(){fwrite(obuf,1,oh-obuf,stdout);}
}
using namespace fast_IO;
#define getchar() getchar_()
#define putchar(x) putchar_((x))
typedef long long ll;
#define rg register
template <typename T> inline T max(const T a,const T b){return a>b?a:b;}
template <typename T> inline T min(const T a,const T b){return a<b?a:b;}
template <typename T> inline T mind(T&a,const T b){a=a<b?a:b;}
template <typename T> inline T maxd(T&a,const T b){a=a>b?a:b;}
template <typename T> inline T abs(const T a){return a>0?a:-a;}
template <typename T> inline void swap(T&a,T&b){T c=a;a=b;b=c;}
template <typename T> inline void swap(T*a,T*b){T c=a;a=b;b=c;}
template <typename T> inline T gcd(const T a,const T b){if(!b)return a;return gcd(b,a%b);}
template <typename T> inline T square(const T x){return x*x;};
template <typename T> inline void read(T&x)
{
    char cu=getchar();x=0;bool fla=0;
    while(!isdigit(cu)){if(cu=='-')fla=1;cu=getchar();}
    while(isdigit(cu))x=x*10+cu-'0',cu=getchar();
    if(fla)x=-x;  
}
template <typename T> void printe(const T x)
{
    if(x>=10)printe(x/10);
    putchar(x%10+'0');
}
template <typename T> inline void print(const T x)
{
    if(x<0)putchar('-'),printe(-x);
    else printe(x);
}
const int maxn=2097152,mod=998244353;
inline int Md(const int x){return x>=mod?x-mod:x;}
template<typename T>
inline int pow(int x,T y)
{
    rg int res=1;x%=mod;
    for(;y;y>>=1,x=(ll)x*x%mod)if(y&1)res=(ll)res*x%mod;
    return res;
}
namespace Poly///////namespace of Poly
{
int W_[maxn],ha[maxn],hb[maxn],Inv[maxn];
inline void init(const int x)
{
    rg int tim=0,lenth=1;
    while(lenth<x)lenth<<=1,tim++;
    for(rg int i=1;i<lenth;i<<=1)
    {
    	const int WW=pow(3,(mod-1)/(i*2));
    	W_[i]=1;
    	for(rg int j=i+1,k=i<<1;j<k;j++)W_[j]=(ll)W_[j-1]*WW%mod;
    }
    Inv[0]=Inv[1]=1;
    for(rg int i=2;i<x;i++)Inv[i]=(ll)(mod-mod/i)*Inv[mod%i]%mod;
}
int L;
inline void DFT(int*A)//prepare:init L 
{
    for(rg int i=0,j=0;i<L;i++)
    {
        if(i>j)swap(A[i],A[j]);
        for(rg int k=L>>1;(j^=k)<k;k>>=1);
    }
    for(rg int i=1;i<L;i<<=1)
        for(rg int j=0,J=i<<1;j<L;j+=J)
            for(rg int k=0;k<i;k++)
            {
                const int x=A[j+k],y=(ll)A[j+k+i]*W_[i+k]%mod;
                A[j+k]=Md(x+y),A[j+k+i]=Md(mod+x-y);
            }
}
void IDFT(int*A)
{
    for(rg int i=1;i<L-i;i++)swap(A[i],A[L-i]);
    DFT(A);
}
inline int Quadratic_residue(const int a)
{
    if(a==0)return 0;
    int b=(rand()<<14^rand())%mod;
    while(pow(b,(mod-1)>>1)!=mod-1)b=(rand()<<14^rand())%mod;
    int s=mod-1,t=0,x,inv=pow(a,mod-2),f=1;
    while(!(s&1))s>>=1,t++,f<<=1;
    t--,x=pow(a,(s+1)>>1),f>>=1;
    while(t)
    {
        f>>=1;
        if(pow((int)((ll)inv*x%mod*x%mod),f)!=1)x=(ll)x*pow(b,s)%mod;
        t--,s<<=1;
    }
    return min(x,mod-x);
}
struct poly
{
    std::vector<int>A;
    poly(){A.resize(0);}
    poly(const int x){A.resize(1),A[0]=x;}
    inline int&operator[](const int x){return A[x];}
    inline int operator[](const int x)const{return A[x];}
    inline void clear(){A.clear();}
    inline unsigned int size()const{return A.size();}
    inline void resize(const unsigned int x){A.resize(x);}
    void RE(const int x)
    {
        A.resize(x);
        for(rg int i=0;i<x;i++)A[i]=0; 
    }
    void readin(const int MAX)
    {
        A.resize(MAX);
        for(rg int i=0;i<MAX;i++)read(A[i]);
    }
    void putout()const
    {
        for(rg unsigned int i=0;i<A.size();i++)print(A[i]),putchar(' ');
    }
    inline poly operator +(const poly b)const
    {
        poly RES;
        RES.resize(max(size(),b.size()));
        for(rg unsigned int i=0;i<RES.size();i++)RES[i]=Md((i<size()?A[i]:0)+(i<b.size()?b[i]:0));
        return RES;
    }
    inline poly operator -(const poly b)const
    {
        poly RES;
        RES.resize(max(size(),b.size()));
        for(rg unsigned int i=0;i<RES.size();i++)RES[i]=Md((i<size()?A[i]:0)+mod-(i<b.size()?b[i]:0));
        return RES;
    }
    inline poly operator *(const int b)const
    {
        poly RES=*this;
        for(rg unsigned int i=0;i<RES.size();i++)RES[i]=(ll)RES[i]*b%mod;
        return RES;
    }
    inline poly operator /(const int b)const
    {
        poly RES=(*this)*pow(b,mod-2);
    	return RES;
    }
    inline poly operator *(const poly b)const
    {
        const int RES=A.size()+b.size()-1;
        L=1;while(L<RES)L<<=1;
        poly c;c.A.resize(RES);
        memset(ha,0,L<<2);
        memset(hb,0,L<<2);
        for(rg unsigned int i=0;i<A.size();i++)ha[i]=A[i];
        for(rg unsigned int i=0;i<b.A.size();i++)hb[i]=b.A[i];
        DFT(ha),DFT(hb);
        for(rg int i=0;i<L;i++)ha[i]=(ll)ha[i]*hb[i]%mod;
        IDFT(ha);
        const int inv=pow(L,mod-2);
        for(rg int i=0;i<RES;i++)c.A[i]=(ll)ha[i]*inv%mod;
        return c;
    }
    inline poly inv()const
    {
        poly C;
        if(A.size()==1){C=*this;C[0]=pow(C[0],mod-2);return C;}
        C.resize((A.size()+1)>>1);
        for(rg unsigned int i=0;i<C.size();i++)C[i]=A[i];
        C=C.inv();
        L=1;while(L<(int)size()*2-1)L<<=1;
        for(rg unsigned int i=0;i<A.size();i++)ha[i]=A[i];
        for(rg unsigned int i=0;i<C.size();i++)hb[i]=C[i];
        memset(ha+A.size(),0,(L-A.size())<<2);
        memset(hb+C.size(),0,(L-C.size())<<2);
        DFT(ha),DFT(hb);
        for(rg int i=0;i<L;i++)ha[i]=(ll)(2-(ll)hb[i]*ha[i]%mod+mod)*hb[i]%mod;
        IDFT(ha);
        const int inv=pow(L,mod-2);
        C.resize(size());
        for(rg unsigned int i=0;i<size();i++)C[i]=(ll)ha[i]*inv%mod;
        return C;
    }
/*    inline poly inv()const
    {
        poly C;
        if(A.size()==1){C=*this;C[0]=pow(C[0],mod-2);return C;}
        C.resize((A.size()+1)>>1);
        for(rg unsigned int i=0;i//大常数版本 
    inline void Reverse(const int n)
    {
    	A.resize(n);
    	for(rg int i=0,j=n-1;i<j;i++,j--)swap(A[i],A[j]);
    }
    inline poly operator /(const poly B)const
    {
    	if(size()<B.size())return 0;
        poly a=*this,b=B;a.Reverse(size()),b.Reverse(B.size());
        b.resize(size()-B.size()+1);
        b=b.inv();
        b=b*a;
        b.Reverse(size()-B.size()+1);
        return b;
    }
    inline poly operator %(const poly B)const
    {
        poly C=(*this)-(*this)/B*B;C.resize(B.size()-1);
        return C;
    }
    inline poly diff()const
    {
        poly C;C.resize(size()-1);
        for(rg unsigned int i=1;i<size();i++)C[i-1]=(ll)A[i]*i%mod;
        return C;
    }
    inline poly inte()const
    {
        poly C;C.resize(size()+1);
        for(rg unsigned int i=0;i<size();i++)C[i+1]=(ll)A[i]*Inv[i+1]%mod;
        C[0]=0;
        return C;
    }
    inline poly ln()const
    {
        poly C=(diff()*inv()).inte();C.resize(size());
        return C;
    }
    inline poly exp()const
    {
        poly C;
        if(size()==1){C=*this;C[0]=1;return C;}
        C.resize((size()+1)>>1);
        for(rg unsigned int i=0;i<C.size();i++)C[i]=A[i];
        C=C.exp();C.resize(size());
        poly D=(poly)1-C.ln()+*this;
        D=D*C;
        D.resize(size());
        return D;
    }
    inline poly sqrt()const
    {
        poly C;
        if(size()==1)
        {
            C=*this;C[0]=Quadratic_residue(C[0]);
            return C;
        }
        C.resize((size()+1)>>1);
        for(rg unsigned int i=0;i<C.size();i++)C[i]=A[i];
        C=C.sqrt();C.resize(size());
        C=(C+*this*C.inv())*(int)499122177;
        C.resize(size());
        return C;
    }
    inline poly operator >>(const unsigned int x)const
    {
    	poly C;if(size()<x){C.resize(0);return C;}
        C.resize(size()-x);
    	for(rg unsigned int i=0;i<C.size();i++)C[i]=A[i+x];
    	return C;
    }
    inline poly operator <<(const unsigned int x)const
    {
    	poly C;C.RE(size()+x);
    	for(rg unsigned int i=0;i<size();i++)C[i+x]=A[i];
    	return C;
    }
    inline poly Pow(const unsigned int x)const
    {
    	for(rg unsigned int i=0;i<size();i++)
            if(A[i])
            {
                poly C=((((*this/A[i])>>i).ln()*x).exp()*pow(A[i],x))<<(min(i*x,size()));
                C.resize(size());
                return C;
            }
    	return *this;
    }
    inline void cheng(const poly&B)
    {
        for(rg unsigned int i=0;i<size();i++)A[i]=(ll)A[i]*B[i]%mod; 
    }
    inline void jia(const poly&B)
    {
        for(rg unsigned int i=0;i<size();i++)A[i]=Md(A[i]+B[i]); 
    }
    inline void dft()
    {
        memset(ha,0,L<<2);
        for(rg unsigned int i=0;i<A.size();i++)ha[i]=A[i];
        DFT(ha);
        resize(L);
        for(rg int i=0;i<L;i++)A[i]=ha[i];
    }
    inline void idft()
    {
        memset(ha,0,L<<2);
        for(rg unsigned int i=0;i<A.size();i++)ha[i]=A[i];
        IDFT(ha);
        const int inv=pow(L,mod-2);
        for(rg int i=0;i<L;i++)A[i]=(ll)ha[i]*inv%mod;
        while(size()&&!A[size()-1])A.pop_back();
    }
};
void fz(const int root,const int l,const int r,std::vector<int>&v,std::vector<poly>&A)
{
    if(l==r)
    {
        A[root].resize(2);
        A[root][0]=(mod-v[l])%mod;
        A[root][1]=1;
        return;
    }
    const int mid=(l+r)>>1;
    fz(root<<1,l,mid,v,A),fz(root<<1|1,mid+1,r,v,A);
    A[root]=A[root<<1]*A[root<<1|1];
}
void calc(const int root,const int l,const int r,std::vector<int>&v,std::vector<poly>&A,std::vector<poly>&B)
{
    if(l==r)
    {
        v[l]=B[root][0];
        return;
    }
    const int mid=(l+r)>>1;
    B[root<<1]=B[root]%A[root<<1];
    B[root<<1|1]=B[root]%A[root<<1|1];
    calc(root<<1,l,mid,v,A,B),calc(root<<1|1,mid+1,r,v,A,B);
}
void multi_point_evaluation(const poly a,std::vector<int>&v)
{
    std::vector<poly>A,B;A.resize(maxn),B.resize(maxn);
    fz(1,0,v.size()-1,v,A);
    B[1]=a%A[1];
    calc(1,0,v.size()-1,v,A,B);
}
void fz2(const int root,const int l,const int r,std::vector<int>&y,std::vector<poly>&A,std::vector<poly>&B)
{
    if(l==r)
    {
        B[root].resize(1),B[root][0]=y[l];
        return;
    }
    const int mid=(l+r)>>1;
    fz2(root<<1,l,mid,y,A,B),fz2(root<<1|1,mid+1,r,y,A,B);
    B[root]=B[root<<1]*A[root<<1|1]+B[root<<1|1]*A[root<<1];
}
poly interpolation(std::vector<int>&x,std::vector<int>&y)
{
    std::vector<poly>A,B;A.resize(maxn),B.resize(maxn);
	fz(1,0,x.size()-1,x,A);
	multi_point_evaluation(A[1].diff(),x);
	for(rg unsigned int i=0;i<x.size();i++)y[i]=(ll)y[i]*pow(x[i],mod-2)%mod;
    fz2(1,0,x.size()-1,y,A,B);
    return B[1];
}
}///////namespace of Poly
int n,a,b;
int head[maxn],nxt[maxn],tow[maxn],tmp,fa[maxn],size[maxn];
inline void addb(const int u,const int v)
{
	tmp++;
	nxt[tmp]=head[u];
	head[u]=tmp;
	tow[tmp]=v;
}
void dfs1(const int u)
{
	size[u]=1;
	for(rg int i=head[u];i;i=nxt[i])
	{
		const int v=tow[i];
		if(v!=fa[u])fa[v]=u,dfs1(v),size[u]+=size[v];
	}
}
int stack[maxn],top,c[maxn],Z,fac=1,ans;bool is[maxn];
std::vector<int>C;
std::vector<Poly::poly>A;
int main()
{
    Poly::init(maxn);///////namespace of Poly
    read(n),read(a),read(b);
    for(rg int i=1;i<n;i++)
    {
    	int u,v;read(u),read(v);
    	addb(u,v),addb(v,u);
    }
    dfs1(a);
    while(b!=a)stack[++top]=b,is[b]=1,b=fa[b];
    stack[++top]=b,is[b]=1,b=fa[b];
    for(rg int i=1;i<=top;i++)
	{
		const int u=stack[i];
		c[i]=1;
		for(rg int j=head[u];j;j=nxt[j])
		{
			const int v=tow[j];
			if(!is[v])c[i]+=size[v];
		}
	}
    Z=n+1;
	for(rg int i=1;i<=n;i++)
	{
		fac=(ll)fac*i%mod;
		if(!is[i])Z=(ll)Z*size[i]%mod;
	}
	fac=(ll)fac*(n+1)%mod;
	C.push_back(0);
	for(rg int i=1;i<=top;i++)c[i]=Md(c[i]+c[i-1]),C.push_back(c[i]);
	A.resize(maxn);
	Poly::fz(1,0,top,C,A);
	Poly::multi_point_evaluation(A[1].diff(),C);
	for(rg int i=0;i<=top;i++)
		if((top-i)&1)ans=Md(ans+mod-(ll)fac*pow((ll)C[i]*Z%mod,mod-2)%mod);
		else ans=Md(ans+(ll)fac*pow((ll)C[i]*Z%mod,mod-2)%mod);
	print((ll)ans*499122177%mod);
	return flush(),0;
}

总结

先是要讲方案数通过类似递归式的形式写出来然后化简
然后用一个很好的加点思路
代码的话差不多就是板子了

你可能感兴趣的:(OI,NTT,多项式多点求值)