【题解】LOJ #6183. 看无可看 生成函数 + 分治FFT

题目

题解
推出f的通项: f n = a × 3 n − b × ( − 1 ) n f_n=a\times 3^n-b\times (-1)^n fn=a×3nb×(1)n
最后我们要求:
∑ s ′ ⊆ s , ∣ s ∣ = k ∏ x ∈ s ′ w x \sum_{s'\subseteq s,|s|=k} \prod_{x\in s'} w^x ss,s=kxswx
这个可以看成生成函数,分治FTT解决(保证不会选重复的数)
g(n)表示选n个数的贡献总和。这个东西是可以卷积的

总结:
注意模数的平方会爆int。
所以用FFT,没有用NTT
因为把Pi打成int。然后resize写错了位置。调了1个多小时,一直以为是FFT出错了。
开始预处理单位根死活不对,丝毫没有察觉到Pi是int、
每一个细节都应该多加注意,并且已经确定没有问题的地方就应该先放一放,然后整个代码的每个细节都必须挨着查!

#include
using namespace std;

#define rep(i,l,r) for(register int i = l ; i <= r ; i++)
#define repd(i,r,l) for(register int i = r ; i >= l ; i--)
#define rvc(i,S) for(register int i = 0 ; i < (int)S.size() ; i++)
#define rvcd(i,S) for(register int i = ((int)S.size()) - 1 ; i >= 0 ; i--)
#define fore(i,x)for (register int i = head[x] ; i ; i = e[i].next)
#define forup(i,l,r) for (register int i = l ; i <= r ; i += lowbit(i))
#define fordown(i,id) for (register int i = id ; i ; i -= lowbit(i))
#define pb push_back
#define prev prev_
#define stack stack_
#define mp make_pair
#define fi first
#define se second
#define lowbit(x) ((x)&(-(x)))

typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
typedef pair<int,int> pr;

const int maxn = 200020;
const double Pi = acos(-1.0);

typedef complex<double> Complex;
typedef vector <int> poly;

const int mod = 99991;
//====================================basic operation===============================
inline void add(int &x, int y) {
  x += y;
  if (x >= mod) {
    x -= mod;
  }
}

inline void sub(int &x, int y) {
  x -= y;
  if (x < 0) {
    x += mod;
  }
}

inline int mul(int x, int y) {
  return (int) ((long long) x * y % mod);
}

inline int power(int x, int y) {
  int res = 1;
  while (y) {
    if (y & 1) {
      res = mul(res, x);
    }
    x = mul(x, x);
    y >>= 1;
  }
  return res;
}

inline int inv(int a) {
  int b = mod, u = 0, v = 1;
  while (a) {
    int t = b / a;
    b -= t * a;
    swap(a, b);
    u -= t * v;
    swap(u, v);
  }
  if (u < 0) {
    u += mod;
  }
  return u;
}
//=======================================================================================


namespace fft { 
int base = 1;
poly rev = {0, 1};
vector<Complex> roots = {Complex(1,0),Complex(1,0)};

void ensure_base(int nbase) { //所有dft需要的预处理
  if (nbase <= base) {
    return;
  }
  rev.resize(1 << nbase);
  for (int i = 0; i < 1 << nbase; ++i) { //预处理翻转位
    rev[i] = rev[i >> 1] >> 1 | (i & 1) << (nbase - 1);
  }
  roots.resize(1 << nbase);
  while (base < nbase) { //预处理单位根
    Complex z(cos(Pi / (1 << base)),sin(Pi / (1 << base)));
    for (int i = 1 << (base - 1); i < 1 << base; ++i) { //处理1 << (base + 1)次单位根
      roots[i << 1] = roots[i];
      roots[i << 1 | 1] = roots[i] * z;
    }
    ++base;
  }
}

void dft(vector<Complex> &a) {
  int n = a.size(), zeros = __builtin_ctz(n);
  ensure_base(zeros);
  int shift = base - zeros;
  for (int i = 0; i < n; ++i) {
    if (i < rev[i] >> shift) {
      swap(a[i], a[rev[i] >> shift]);
    }
  }
  for (int i = 1; i < n; i <<= 1) {
    for (int j = 0; j < n; j += i << 1) {
      for (int k = 0; k < i; ++k) {
        Complex x = a[j + k], y = a[j + k + i] * roots[i + k];
        a[j + k] = x + y;
        a[j + k + i] = x - y;
      }
    }
  }
}
vector <Complex> tmpa,tmpb;
poly multiply(poly a, poly b) { //先转化成complex再运算,注意取模,注意resize的问题
  int need = a.size() + b.size() - 1, nbase = 0;
  while (1 << nbase < need) {
    ++nbase;
  }
  ensure_base(nbase);
  int sz = 1 << nbase;
  tmpa.resize(sz) , tmpb.resize(sz);
  fill(tmpa.begin(),tmpa.end(),Complex(0,0)) , fill(tmpb.begin(),tmpb.end(),Complex(0,0));
  rep(i,0,a.size() - 1) tmpa[i] = Complex(a[i],0);
  rep(i,0,b.size() - 1) tmpb[i] = Complex(b[i],0);

  dft(tmpa) , dft(tmpb);

  rep(i,0,sz - 1) tmpa[i] *= tmpb[i];

  dft(tmpa);
  reverse(tmpa.begin() + 1,tmpa.end());
  a.resize(need);
  rep(i,0,need - 1) a[i] = (ll)(tmpa[i].real() / sz + 0.5) % mod;
  return a;
}

}

using fft::multiply;

poly& operator *= (poly &a, const poly &b) {
  if ((int) min(a.size(), b.size()) < 128) {
    poly c = a;
    a.assign(a.size() + b.size() - 1, 0);
    for (int i = 0; i < (int) c.size(); ++i) {
      for (int j = 0; j < (int) b.size(); ++j) {
        add(a[i + j], mul(c[i], b[j]));
      }
    }
  } else {
    a = multiply(a, b);
  }
  return a;
}

poly operator * (const poly &a, const poly &b) {
  poly c = a;
  return c *= b;
}

int n,k;
int a[maxn],A,B;
poly f;

poly solve(int l,int r,int w){
  if ( l == r ) return poly({1,power(w,a[l])});
  int mid = (l + r) >> 1;
  return solve(l,mid,w) * solve(mid + 1,r,w);
}
int main(){
  //freopen("input.txt","r",stdin);
  scanf("%d %d",&n,&k);
  rep(i,1,n) scanf("%d",&a[i]);
  int f0,f1;
  scanf("%d %d",&f0,&f1);
  A = (ll)(f0 + f1) * power(4,mod - 2) % mod;
  B = (f0 - A + mod) % mod;

  int ans = 0;
  f = solve(1,n,3);

  ans = mul(f[k],A);
  f = solve(1,n,mod - 1);
  add(ans,mul(f[k],B));
  cout<<ans<<endl;
}

你可能感兴趣的:(多项式,生成函数,LOJ,题解,FTT&NTT)