Codechef CUTTREE 任意模数FFT+点分治

题意

共进行 n-1 次删除操作,每次等概率随机删掉一条尚未删除的边;一个连通块的价值是其大小的平方。对删除 0, 1, …, n-1 条边后的每个时刻,求所有连通块价值之和的期望(模 $10^9+7$),$n \le 10^5$。

分析

这道题只是“精神AC”——还有两个测试点没能通过。

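首先,一个连通块的价值(大小的平方)等于块内有序点对 $(x,y)$(包括 $x=y$)的个数,所以总价值就是仍然连通的有序点对数。点对 $(x,y)$ 连通,当且仅当它们路径上的 $\mathrm{dist}(x,y)$ 条边都没有被删掉;删掉 $i$ 条边相当于从 $n-1$ 条边中等概率选出 $i$ 条,因此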

$$ans_i=\sum_{x,y}\frac{(n-1-\mathrm{dist}(x,y))^{\underline{i}}}{(n-1)^{\underline{i}}}$$ 其中 $m^{\underline{i}}=m(m-1)\cdots(m-i+1)$ 为下降幂。

然后展开
$$ans_i=\frac{(n-1-i)!}{(n-1)!}\sum_{d=0}^{n-1}(n-1-d)!\,cnt[d]\,\frac{1}{(n-1-i-d)!}$$(约定 $k<0$ 时 $\frac{1}{k!}=0$)

其中 $cnt[d]$ 表示距离为 $d$ 的有序点对个数($cnt[0]=n$)。$cnt$ 用点分治求,合并各子树的深度数组时做普通 FFT;上式是关于 $d$ 的卷积,而模数 $10^9+7$ 不适合 NTT,所以最后用拆系数 FFT 计算。

我的拆系数写法可能不够优美,一共做了 7 次 DFT(4 次正变换,3 次逆变换)。

代码

#include <algorithm>
#include <climits>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>

// Short-hand macros kept for compatibility (none are used in the code shown).
#define pb push_back
#define cl clear
#define MP make_pair
#define bin(i) (1 << (i))

using namespace std;

// 64-bit integer type used for counts, indices and modular arithmetic.
using ll = long long;

// Capacity of every global array: enough for 2*(n-1) adjacency entries and
// for FFT buffers up to the next power of two above 2n, with n <= 1e5.
constexpr ll N = 400010;
// Prime modulus for the printed answers.
constexpr ll Mod = 1000000007;

// Reads one signed decimal integer from stdin.
// Non-digit characters before the number are skipped; a '-' seen while
// skipping makes the result negative. If end-of-file is hit before any digit,
// returns 0 instead of looping forever (the old char-typed version could
// never observe EOF and spun indefinitely on truncated input).
inline ll read()
{
  ll p=0; ll f=1;
  int ch=getchar();  // int, so EOF stays distinguishable from real bytes
  while(ch!=EOF && (ch<'0' || ch>'9')){if(ch=='-') f=-1; ch=getchar();}
  while(ch>='0' && ch<='9'){p=p*10+ch-'0'; ch=getchar();}
  return p*f;
}

// In-place minimum: lowers x to y when y is smaller (unused in the code shown).
void upd(ll &x,ll y){ x = (y < x) ? y : x; }
// Adjacency list as linked edge records: first[x] is the most recently added
// edge leaving x (main sets all entries to -1 = "no edge"), and
// edge[k].next links to the previous edge leaving the same vertex.
struct E{ll x,y,next;}edge[N];
ll len,first[N];

// Appends the directed edge x -> y; edge ids start at 1.
void ins(ll x,ll y)
{
  E &e = edge[++len];
  e.x = x;
  e.y = y;
  e.next = first[x];
  first[x] = len;
}

// Complex number with long double components, used as FFT coefficients.
// wn[] is filled in main with wn[i] = (cos(pi/i), sin(pi/i)).
struct node
{
  long double r,i;  // real and imaginary parts
  node(){}
  node(long double _r,long double _i):r(_r),i(_i){}
  friend node operator + (const node &x,const node &y)
  {
    return node(x.r + y.r, x.i + y.i);
  }
  friend node operator - (const node &x,const node &y)
  {
    return node(x.r - y.r, x.i - y.i);
  }
  friend node operator * (const node &x,const node &y)
  {
    long double re = x.r*y.r - x.i*y.i;
    long double im = x.r*y.i + y.r*x.i;
    return node(re, im);
  }
}wn[N];

// pi at full long double precision. acos(-1) with an int argument resolves to
// the double overload, so the old value had only ~16 significant digits and
// limited the accuracy of the long double FFT; -1.0L selects acos(long double).
long double pi = std::acos(-1.0L);

ll R[N];  // bit-reversal permutation for the current length; set by the caller

// In-place iterative radix-2 FFT of length n (a power of two), using the
// permutation stored in R[].
//   op ==  1: forward transform with root e^{+2*pi*i/n}.
//   op == -1: inverse transform WITHOUT the 1/n scaling (callers divide by n);
//             implemented as a forward transform followed by reversing a[1..n-1].
// Twiddle factors come from a table built with direct cos/sin calls. The
// previous version stepped w = w*wn[i] inside each block, whose rounding error
// grows with the block length; with outputs near 1e14 that error can flip a
// rounded coefficient.
void DFT(node *a,ll n,ll op)
{
  // root[k] = e^{i*pi*k/H}, H = root.size() (a power of two); grown on demand.
  static std::vector<node> root;
  if((ll)root.size() < n/2)
  {
    const long double PI = std::acos(-1.0L);
    const ll half = n/2;
    root.resize(half);
    for(ll k=0;k<half;k++) root[k] = node(std::cos(PI*k/half), std::sin(PI*k/half));
  }

  for(ll i=1;i<n;i++) if(R[i] > i) swap(a[R[i]],a[i]);
  for(ll i=1;i<n;i<<=1)
  {
    // For butterflies of half-length i the k-th twiddle e^{i*pi*k/i} is root[k*stride].
    const ll stride = (ll)root.size()/i;
    for(ll j=0;j<n;j+=(i<<1))
    {
      for(ll k=0;k<i;k++)
      {
        node x = a[j+k]; node y = a[j+k+i]*root[k*stride];
        a[j+k] = x+y; a[j+k+i] = x-y;
      }
    }
  }
  if(op == -1) reverse(a+1,a+n);
}

// c, d: FFT scratch buffers.
// cnt[d]: number of ordered vertex pairs at distance d (mod Mod), filled by dfs/FFT.
node c[N],d[N]; ll cnt[N];

// Convolves a[0..n] with b[0..m] and adds twice each product coefficient to
// cnt[] (mod Mod). The factor 2 makes every unordered pair (u, v) counted by
// the convolution contribute both (u, v) and (v, u).
// Overwrites c[], d[] and R[] for the chosen transform length.
void FFT(ll *a,ll *b,ll n,ll m)
{
  ll mm = n+m; ll l=0,nn; for(nn=1;nn<=mm;nn<<=1) l++;

  // Clear the whole transform length: earlier, larger calls leave data behind.
  for(ll i=0;i<nn;i++) c[i] = d[i] = node(0,0);
  for(ll i=0;i<=n;i++) c[i].r = a[i];
  for(ll i=0;i<=m;i++) d[i].r = b[i];

  for(ll i=1;i<nn;i++) R[i] = (R[i>>1] >> 1) | ((i&1) << (l-1));
  DFT(c,nn,1); DFT(d,nn,1);
  for(ll i=0;i<nn;i++) c[i] = c[i]*d[i];
  DFT(c,nn,-1);

  // The inverse DFT is unscaled: divide by nn and round to the nearest integer.
  for(ll i=0;i<=mm;i++) cnt[i] = (cnt[i] + (ll)(c[i].r/nn+0.5) * 2) % Mod;
}

ll a[N],b[N];        // depth-count buffers used while processing a centroid
ll siz[N],p=0,mx;    // subtree sizes; centroid candidate and its largest-part size
bool vis[N];         // vertices already used as centroids
ll MAX_DEP = 0;      // largest depth reached by dfs2 since the caller last reset it

// Recomputes siz[] for the component containing x, rooted at x with parent f
// and skipping vertices marked in vis[]. Raises MAX_DEP to the largest depth
// seen, where dis is the depth assigned to x.
void dfs2(ll x,ll f,ll dis)
{
  siz[x] = 1;
  for(ll e=first[x]; e!=-1; e=edge[e].next)
  {
    const ll to = edge[e].y;
    if(to!=f && !vis[to])
    {
      dfs2(to,x,dis+1);
      siz[x] += siz[to];
    }
  }
  MAX_DEP = (dis > MAX_DEP) ? dis : MAX_DEP;
}

// Walks the component rooted at x (parent f), using siz[] from the latest dfs2
// on that root, and records in p the vertex whose largest remaining piece
// after removal is smallest. That piece size is stored in mx.
// tot is the component size; the caller must initialise mx beforehand.
void Find_root(ll x,ll f,ll tot)
{
  ll worst = tot - siz[x];  // the part of the component outside x's subtree
  for(ll e=first[x]; e!=-1; e=edge[e].next)
  {
    const ll to = edge[e].y;
    if(to==f || vis[to]) continue;
    Find_root(to,x,tot);
    if(siz[to] > worst) worst = siz[to];
  }
  if(worst < mx)
  {
    mx = worst;
    p = x;
  }
}

// Counts every vertex of the subtree entered at x (coming from f, skipping
// vertices marked in vis[]) into b[], indexed by depth; dd is x's depth.
void calc(ll x,ll f,ll dd)
{
  ++b[dd];
  for(ll e=first[x]; e!=-1; e=edge[e].next)
  {
    const ll to = edge[e].y;
    if(!vis[to] && to!=f) calc(to,x,dd+1);
  }
}

// Centroid decomposition of the component containing x.
// For the chosen centroid it:
//   - adds 1 to cnt[0] for the centroid itself (the pair (c, c));
//   - uses FFT() to add every vertex pair whose path passes through the
//     centroid, at distance d, into cnt[d];
//   - recurses into each remaining piece.
//
// Child subtrees are merged in increasing order of depth, and each FFT is
// sized by the depths actually involved (current accumulated depth + this
// child's depth). Previously every FFT was sized by the deepest subtree, so a
// centroid with one deep branch and many shallow ones cost
// O(deg * depth * log), which is quadratic in the worst case.
void dfs(ll x)
{
  dfs2(x,0,0); mx = LLONG_MAX; Find_root(x,0,siz[x]); x=p;
  MAX_DEP = 0; dfs2(x,0,0); vis[x] = 1;
  const ll total_dep = MAX_DEP;

  // (max depth of the child's subtree, child) for every remaining neighbour.
  vector<pair<ll,ll> > ch;
  for(ll k=first[x];k!=-1;k=edge[k].next)
  {
    ll y = edge[k].y;
    if(vis[y]) continue;
    MAX_DEP = 0; dfs2(y,x,1);
    ch.push_back(make_pair(MAX_DEP,y));
  }
  sort(ch.begin(),ch.end());

  for(ll i=0;i<=total_dep;i++) a[i] = 0;
  a[0] = 1; cnt[0] ++;   // centroid at depth 0; counts the pair (x, x)
  ll cur = 0;            // deepest index that may be non-zero in a[]
  for(size_t t=0;t<ch.size();t++)
  {
    const ll dep = ch[t].first, y = ch[t].second;
    calc(y,x,1);         // b[1..dep] = depth counts of this child's subtree
    FFT(a,b,cur,dep);    // pairs (centroid or earlier subtree, this subtree)
    for(ll i=0;i<=dep;i++) a[i] += b[i], b[i] = 0;
    cur = max(cur,dep);
  }

  for(ll k=first[x];k!=-1;k=edge[k].next)
  {
    ll y = edge[k].y;
    if(vis[y]) continue ;
    dfs(y);
  }
}

ll fac[N],inv[N];

node ka[N],kb[N],ba[N],bb[N];

ll ans[N];

int main()
{

  ll n = read(); len = 0; memset(first,-1,sizeof(first));
  ll m=n+n; ll nn,l=0; for(nn=1;nn<=m;nn<<=1) l++;

  for(int i=1;i1) wn[i] = node(cos(pi/i),sin(pi/i));

  for(int i=1;ix = read(); ll y = read();
    ins(x,y); ins(y,x);
  }
  dfs(1);

  fac[0] = 1; for(int i=1;i<=n;i++) fac[i] = fac[i-1] * i % Mod;
  inv[0] = inv[1] = 1; for(int i=2;i<=n;i++) inv[i] = (Mod - Mod/i) * inv[Mod%i] % Mod;
  for(int i=1;i<=n;i++) inv[i] = inv[i-1] * inv[i] % Mod;

  ll qm = (ll)(ceil(sqrt(Mod)));

  for(int i=0;ix = fac[n-1-i] * cnt[i] % Mod; // printf("%lld ",x);
    ka[i].r = x/qm; ba[i].r = x%qm;
//    printf("%lld %lld\n",x/qm,x%qm);
  }// printf("\n"); 

  for(int i=0;ix = inv[i]; // printf("%lld ",x);
    kb[i].r = x/qm; bb[i].r = x%qm;
//    printf("%lld %lld\n",x/qm,x%qm);
  }// printf("\n");


  for(int i=1;i>1] >> 1) | ((i&1) << (l-1));
  DFT(ka,nn,1); DFT(kb,nn,1); DFT(ba,nn,1); DFT(bb,nn,1);

//  for(ll i=0;i//    printf("%f %f %f %f\n",ka[i].r,kb[i].r,ba[i].r,bb[i].r);
//  }

  for(int i=0;i0;
  for(int i=0;i1);
  for(int i=0;i0.5) % Mod * qm % Mod * qm % Mod) % Mod;

  for(int i=0;i0;
  for(int i=0;i1);
  for(int i=0;i0.5) % Mod * qm % Mod) % Mod;

  for(int i=0;i0;
  for(int i=0;i1);
  for(int i=0;i0.5) % Mod) % Mod; 

  for(int i=0;i1-i] = (ans[n-1-i] * inv[n-1] % Mod * fac[n-1-i] % Mod);

  for(int i=0;iprintf("%lld ",ans[n-1-i]);
  // cout<<(double)(clock())<return 0;
}

你可能感兴趣的:(Codechef)