【学习总结】生成函数题目,多项式模板

学习资料:
王乐平、策爷冬令营讲义。

多项式技巧

牛顿迭代(泰勒展开)

很多时候推式子就是取对数、积分,exp和泰勒展开的结合
泰勒展开在 $\bmod x^n$ 意义下只需要保留前 $n$ 项,性质非常优美。
而插值不同:对于一个 $k$ 次多项式,必须代入 $k+1$ 个点值才能确定它;即使只求该多项式的前 $n$ 项,也不能只代入 $n+1$ 个点。

这是一篇非常好的博客。特别是把所有多项式操作都用泰勒展开推导,以后就不用再死记,也不用担心推错了!from yyb

复合逆

拉格朗日反演
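若 $f(g(x)) = x$(即 $g$ 是 $f$ 的复合逆,要求 $f(0)=0$、$[x^1]f \neq 0$),则
$$[x^n]\,g(x) = \frac{1}{n}\,[x^{n-1}]\left(\frac{x}{f(x)}\right)^n$$
(原文此处为公式图片)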

注意当 $f(x)$ 本身不存在逆元(常数项为 $0$)时,可以对 $f(x)/x$ 求逆;否则 $\left(x/f(x)\right)^n$ 的第 $n-1$ 项会被算成 $0$。

例题

直接推生成函数

【BZOJ3625】小朋友和二叉树

from cz_xuyixuan
注意这道题是普通生成函数,不是指数型生成函数

代码只给出了 main 和 pre_calc 函数,多项式模板在下面

int n,m;
vector <int> fact,inv_fact;
vector <int> f,g,h,tmp;

void pre_calc(){
 // g[0] = 1;
  rep(i,0,m){
   // g[i] = mul(g[i],inv_fact[i]);
    sub(f[i],mul(g[i],4));
 //   cout<
  }
  f[0] = 1;
 // cout<
  f = sqrt(f);
  add(f[0],1);
  f = inverse(f);
  rep(i,0,m) f[i] = mul(f[i],2);
}

// Reads n weights, counts them into g (only weights <= m are kept),
// then prints coefficients 1..m of the answer series.
int main() {
  scanf("%d %d", &n, &m);
  f.resize(m + 1);
  g.resize(m + 1);
  for (int i = 1; i <= n; ++i) {
    int weight;
    scanf("%d", &weight);
    if (weight <= m) {
      ++g[weight];
    }
  }
  pre_calc();
  for (int i = 1; i <= m; ++i) {
    printf("%d\n", f[i]);
  }
}

bzoj 3684 大朋友和多叉树

复合逆裸题。
注意F(x) / x才可以求逆

// Problem state: n leaves, m allowed child counts.
int n, m;
poly fact, inv_fact;
poly f, g, h, tmp;

// Lagrange inversion. On entry g holds -sum_d y^{d-1} (written by main).
// With G(y) = y - sum_d y^d, the answer is (1/n) [y^{n-1}] (y / G(y))^n,
// and y / G(y) = 1 / (1 - sum_d y^{d-1}). Prints the answer.
void pre_calc() {
  add(g[0], 1);                  // g = 1 - sum_d y^{d-1} = G(y) / y
  g = power(inverse(g), n);      // (y / G(y))^n, truncated to g's length
  const int answer = mul(power(n, mod - 2), g[n - 1]);
  cout << answer << endl;
}
// Reads n (number of leaves) and m allowed child counts d, and stores
// g = -sum_d y^{d-1} truncated to y^n. Then pre_calc prints the answer.
int main(){
  scanf("%d %d",&n,&m);
  f.resize(n + 1) , g.resize(n + 1);
  rep(i,1,m){
    int x;
    scanf("%d",&x);
    // g only has coefficients 0..n. Terms y^{d-1} with d - 1 > n cannot influence
    // the y^{n-1} coefficient that pre_calc reads, and writing them would index
    // past the end of g.
    if (x >= 1 && x - 1 <= n) {
      g[x - 1] = mod - 1;
    }
  }
  pre_calc();
}

多项式模板

注意事项:
求导和积分多项式的次数变化
开根号时如果常数项不是完全平方数,需要用二次剩余(如 Cipolla 算法)在模意义下开根号的模板(原文此处的参考链接已丢失)。
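求逆时常数项不能为 0,否则不存在逆元
模板中的 mod = 950009857(= 453·2^21 + 1,NTT 长度至多 2^21)是 BZOJ3684 要求的模数,做其他题目时需按题意修改(例如 BZOJ3625 为 998244353)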
special thanks to wxh010910

#include <bits/stdc++.h>
using namespace std;

// Loop helpers. The `register` specifier was removed: it has been deprecated since
// C++11 and is ill-formed in C++17. Macro arguments are parenthesized so that
// compound expressions expand as intended.
#define rep(i,l,r) for(int i = (l) ; i <= (r) ; i++)
#define repd(i,r,l) for(int i = (r) ; i >= (l) ; i--)
#define rvc(i,S) for(int i = 0 ; i < (int)(S).size() ; i++)
#define rvcd(i,S) for(int i = ((int)(S).size()) - 1 ; i >= 0 ; i--)
#define fore(i,x) for (int i = head[x] ; i ; i = e[i].next)
#define forup(i,l,r) for (int i = (l) ; i <= (r) ; i += lowbit(i))
#define fordown(i,id) for (int i = (id) ; i ; i -= lowbit(i))
#define pb push_back
#define prev prev_
#define stack stack_
#define mp make_pair
#define fi first
#define se second
#define lowbit(x) ((x)&(-(x)))

typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
typedef std::pair<int,int> pr;

const int maxn = 200020;

// Polynomials / power series are coefficient vectors, lowest degree first.
typedef std::vector<int> poly;
// NTT-friendly prime: mod - 1 = 453 * 2^21.
const int mod = 950009857;

//NOTES: 任意乘法需要用mul,或者强制用long long。
//注意取模
//====================================basic operation===============================
// x <- (x + y) mod `mod`. Assumes both operands are in [0, mod), so a single
// conditional subtraction is enough.
inline void add(int &x, int y) {
  const int sum = x + y;
  x = sum >= mod ? sum - mod : sum;
}

// x <- (x - y) mod `mod`. Assumes both operands are in [0, mod), so a single
// conditional addition is enough.
inline void sub(int &x, int y) {
  const int diff = x - y;
  x = diff < 0 ? diff + mod : diff;
}

// Returns x * y mod `mod`, widening to 64 bits so the product cannot overflow.
inline int mul(int x, int y) {
  const long long product = static_cast<long long>(x) * y;
  return static_cast<int>(product % mod);
}

// Binary exponentiation: returns x^y mod `mod` for y >= 0.
inline int power(int x, int y) {
  int result = 1;
  for (; y; y >>= 1) {
    if (y & 1) {
      result = mul(result, x);
    }
    x = mul(x, x);
  }
  return result;
}

inline int inv(int a) {
  int b = mod, u = 0, v = 1;
  while (a) {
    int t = b / a;
    b -= t * a;
    swap(a, b);
    u -= t * v;
    swap(u, v);
  }
  if (u < 0) {
    u += mod;
  }
  return u;
}
//=======================================================================================

// Number-theoretic transform modulo `mod`, plus NTT-based polynomial multiplication
// and power-series inversion. The lookup tables are built lazily and only grow.
namespace ntt { 
// base: log2 of the length the rev/roots tables currently cover.
// root: element of multiplicative order exactly 2^max_base (found by init()).
// max_base: largest t with 2^t | mod - 1; -1 means init() has not run yet.
int base = 1, root = -1, max_base = -1;
// rev: bit-reversal permutation for length 2^base.
// roots: roots[i + k] (i a power of two, 0 <= k < i) is the k-th power of a
// primitive (2i)-th root of unity; dft reads these in its length-2i butterflies.
poly rev = {0, 1}, roots = {0, 1};

// Computes max_base and finds the smallest root >= 2 whose order is exactly
// 2^max_base, i.e. root^(2^max_base) == 1 but root^(2^(max_base-1)) != 1.
void init() {
  int temp = mod - 1;
  max_base = 0;
  while (temp % 2 == 0) {
    temp >>= 1;
    ++max_base;
  }
  root = 2;
  while (true) {
    if (power(root, 1 << max_base) == 1 && power(root, 1 << (max_base - 1)) != 1) {
      break;
    }
    ++root;
  }
}

void ensure_base(int nbase) { // precomputation needed by every dft: grow tables to length 2^nbase
  if (max_base == -1) {
    init();
  }
  if (nbase <= base) {
    return;
  }
  // A transform cannot be longer than the largest power of two dividing mod - 1.
  assert(nbase <= max_base);
  rev.resize(1 << nbase);
  for (int i = 0; i < 1 << nbase; ++i) { // precompute the bit-reversal permutation
    rev[i] = rev[i >> 1] >> 1 | (i & 1) << (nbase - 1);
  }
  roots.resize(1 << nbase);
  while (base < nbase) { // precompute roots of unity, one doubling level at a time
    // z is a primitive 2^(base+1)-th root of unity; odd slots get one extra factor of z.
    int z = power(root, 1 << (max_base - 1 - base));
    for (int i = 1 << (base - 1); i < 1 << base; ++i) {
      roots[i << 1] = roots[i];
      roots[i << 1 | 1] = mul(roots[i], z);
    }
    ++base;
  }
}

// In-place forward NTT: bit-reversal permutation followed by iterative butterflies
// of doubling length; the output is in natural order. Assumes a.size() is a
// nonzero power of two, because its log2 is taken from __builtin_ctz.
void dft(poly &a) {
  int n = a.size(), zeros = __builtin_ctz(n);
  ensure_base(zeros);
  // rev is built for length 2^base; shifting right adapts it to length 2^zeros.
  int shift = base - zeros;
  for (int i = 0; i < n; ++i) {
    if (i < rev[i] >> shift) {
      swap(a[i], a[rev[i] >> shift]);
    }
  }
  for (int i = 1; i < n; i <<= 1) {
    for (int j = 0; j < n; j += i << 1) {
      for (int k = 0; k < i; ++k) {
        // x, y < mod, so x + y and x + mod - y fit in int as long as 2 * mod < 2^31.
        int x = a[j + k], y = mul(a[j + k + i], roots[i + k]);
        a[j + k] = (x + y) % mod;
        a[j + k + i] = (x + mod - y) % mod;
      }
    }
  }
}

// Returns the product a * b (length a.size() + b.size() - 1). Both operands are
// zero-padded to the next power of two. When they are equal (squaring), only one
// forward transform is done.
// NOTE(review): empty operands are not special-cased. With both empty, `need`
// wraps to -1 on typical platforms and the final resize fails.
poly multiply(poly a, poly b) {
  int need = a.size() + b.size() - 1, nbase = 0;
  while (1 << nbase < need) {
    ++nbase;
  }
  ensure_base(nbase);
  int sz = 1 << nbase;
  a.resize(sz);
  b.resize(sz);
  bool equal = a == b;
  dft(a);
  if (equal) {
    b = a;
  } else {
    dft(b);
  }
  int inv_sz = inv(sz);
  for (int i = 0; i < sz; ++i) {
    a[i] = mul(mul(a[i], b[i]), inv_sz);
  }
  // Equivalent to NTT(a, -1). Reversing a[1..] turns the next forward dft into the
  // inverse transform; the 1/sz scaling was applied in the loop above.
  reverse(a.begin() + 1, a.end());
  dft(a);
  a.resize(need);
  return a;
}

// Power-series inverse: returns b with a * b == 1 (mod x^n), where n = a.size().
// Uses Newton iteration b <- b * (2 - a * b), recursing on the first ceil(n/2)
// coefficients.
// NOTE(review): an empty `a` reaches the recursive branch with m == 0 and recurses forever.
poly inverse(poly a) { // the constant term must be nonzero, otherwise no inverse exists!
  int n = a.size(), m = (n + 1) >> 1;
  if (n == 1) {
    return poly(1, inv(a[0]));
  } else {
    poly b = inverse(poly(a.begin(), a.begin() + m));
    int need = n << 1, nbase = 0;
    while (1 << nbase < need) {
      ++nbase;
    }
    ensure_base(nbase);
    int sz = 1 << nbase;
    a.resize(sz);
    b.resize(sz);
    dft(a);
    dft(b);
    int inv_sz = inv(sz);
    for (int i = 0; i < sz; ++i) {
      // Pointwise (2 - a*b) * b, pre-scaled by 1/sz for the inverse transform below.
      a[i] = mul(mul(mod + 2 - mul(a[i], b[i]), b[i]), inv_sz);
    }
    reverse(a.begin() + 1, a.end());
    dft(a);
    a.resize(n);
    return a;
  }
}
}

using ntt::multiply;
using ntt::inverse;

// a += b coefficient-wise modulo `mod`; a is zero-extended to b's length first.
poly& operator += (poly &a, const poly &b) {
  a.resize(std::max(a.size(), b.size()));
  for (size_t i = 0; i < b.size(); ++i) {
    add(a[i], b[i]);
  }
  return a;
}

// Returns a + b (coefficient-wise, modulo `mod`).
poly operator + (const poly &a, const poly &b) {
  poly sum(a);
  sum += b;
  return sum;
}

// a -= b coefficient-wise modulo `mod`; a is zero-extended to b's length first.
poly& operator -= (poly &a, const poly &b) {
  a.resize(std::max(a.size(), b.size()));
  for (size_t i = 0; i < b.size(); ++i) {
    sub(a[i], b[i]);
  }
  return a;
}

// Returns a - b (coefficient-wise, modulo `mod`).
poly operator - (const poly &a, const poly &b) {
  poly difference(a);
  difference -= b;
  return difference;
}

// a *= b (polynomial product). Short operands use schoolbook multiplication, longer
// ones the NTT-based multiply.
// An empty operand (the zero polynomial) yields an empty result. `a *= a` is safe.
poly& operator *= (poly &a, const poly &b) {
  if (a.empty() || b.empty()) {
    // Without this, a.size() + b.size() - 1 underflows.
    a.clear();
    return a;
  }
  if ((int) min(a.size(), b.size()) < 128) {
    // Accumulate into a separate buffer. b may alias a, so a must stay intact
    // until the product is complete.
    poly product(a.size() + b.size() - 1, 0);
    for (int i = 0; i < (int) a.size(); ++i) {
      for (int j = 0; j < (int) b.size(); ++j) {
        add(product[i + j], mul(a[i], b[j]));
      }
    }
    a.swap(product);
  } else {
    a = multiply(a, b);  // takes both by value, so aliasing is harmless here
  }
  return a;
}

// Returns the polynomial product a * b.
poly operator * (const poly &a, const poly &b) {
  poly product(a);
  product *= b;
  return product;
}

// Polynomial floor division a <- a / b, using the reversal trick:
// rev(Q) = rev(A) * rev(B)^{-1} mod x^{n-m+1}.
// The leading coefficient of b (b.back()) must be nonzero for the inverse to exist.
poly& operator /= (poly &a, const poly &b) {
  const int n = a.size(), m = b.size();
  if (n < m) {
    a.clear();  // deg a < deg b: the quotient is zero
    return a;
  }
  const int quotient_len = n - m + 1;
  // Copy b before touching a, since b may alias a.
  poly reversed_b(b.rbegin(), b.rend());
  reversed_b.resize(quotient_len);
  reverse(a.begin(), a.end());
  a *= inverse(reversed_b);
  a.resize(quotient_len);
  reverse(a.begin(), a.end());
  return a;
}

// Returns the polynomial quotient floor(a / b).
poly operator / (const poly &a, const poly &b) {
  poly quotient(a);
  quotient /= b;
  return quotient;
}

// Polynomial remainder a <- a mod b, computed as a - (a / b) * b.
// The result keeps b.size() - 1 coefficients. If a is already shorter than b, it
// is returned unchanged.
poly& operator %= (poly &a, const poly &b) {
  const int n = a.size(), m = b.size();
  if (n < m) {
    return a;
  }
  const poly multiple = (a / b) * b;
  a.resize(m - 1);
  for (int i = 0; i + 1 < m; ++i) {
    sub(a[i], multiple[i]);
  }
  return a;
}

// Returns the polynomial remainder a mod b.
poly operator % (const poly &a, const poly &b) {
  poly remainder(a);
  remainder %= b;
  return remainder;
}

// Formal derivative: sum a_i x^i  ->  sum i * a_i x^{i-1}.
// The result has a.size() - 1 coefficients. A constant or empty input yields an
// empty polynomial.
poly derivative(const poly &a) {
  if (a.empty()) {
    return poly();  // previously constructed poly(-1), i.e. a huge allocation
  }
  int n = a.size();
  poly b(n - 1);
  for (int i = 1; i < n; ++i) {
    b[i - 1] = mul(a[i], i);
  }
  return b;
}

// Formal antiderivative with zero constant term: result[i] = a[i-1] / i.
// The result has a.size() + 1 coefficients. The inverses 1..len are built with
// the linear recurrence inv[i] = -(mod / i) * inv[mod % i].
poly primitive(const poly &a) {
  const int len = a.size();
  poly result(len + 1);
  poly inverses(len + 1);
  if (len >= 1) {
    inverses[1] = 1;
  }
  for (int i = 2; i <= len; ++i) {
    inverses[i] = mul(mod - mod / i, inverses[mod % i]);
  }
  for (int i = 1; i <= len; ++i) {
    result[i] = mul(a[i - 1], inverses[i]);
  }
  return result;
}

// Power-series logarithm: ln A = integral(A' / A), truncated to a.size() terms.
// The integration constant is 0, so the result is meaningful only when a[0] == 1.
poly logarithm(const poly &a) {
  const poly ratio = derivative(a) * inverse(a);
  poly result = primitive(ratio);
  result.resize(a.size());
  return result;
}

// Power-series exponential: returns b = exp(a) mod x^{a.size()}.
// Uses Newton iteration b <- b * (1 + a - ln b), doubling the precision each round.
// The iteration starts from b = 1, so the result is meaningful only when a[0] == 0
// (not checked here).
poly exponent(const poly &a) {
  poly b(1, 1);
  while (b.size() < a.size()) {
    // c = 1 + a, truncated to the new precision 2 * b.size().
    poly c(a.begin(), a.begin() + min(a.size(), b.size() << 1));
    add(c[0], 1);
    poly old_b = b;
    b.resize(b.size() << 1);
    c -= logarithm(b);  // 1 + a - ln b
    c *= old_b;         // b * (1 + a - ln b)
    // Newton leaves the lower half unchanged, so only the upper half is copied.
    for (int i = b.size() >> 1; i < (int) b.size(); ++i) {
      b[i] = c[i];
    }
  }
  b.resize(a.size());
  return b;
}

// Power-series power: returns a^m mod x^{a.size()} for m >= 0 (a general-exponent
// version that handles leading zero coefficients).
// Writes a = a[p] x^p (1 + ...), where p is the first nonzero index. Then
// a^m = a[p]^m x^{mp} exp(m ln(a / (a[p] x^p))), which needs ln/exp only of a
// series with constant term 1.
poly power(const poly &a, int m) {
  int n = a.size(), p = -1;
  poly b(n);
  if (n == 0) {
    return b;  // nothing to compute (previously wrote b[0] out of bounds)
  }
  if (m == 0) {
    // a^0 = 1. Handled up front: the general path would read a[i + p] past the
    // end when p > 0.
    b[0] = 1;
    return b;
  }
  for (int i = 0; i < n; ++i) {
    if (a[i]) {
      p = i;
      break;
    }
  }
  if (p == -1) {
    b[0] = !m;  // zero series (m > 0 here, so this stays 0)
    return b;
  }
  if ((long long) m * p >= n) {
    return b;  // x^{mp} is already beyond the precision
  }
  const int len = n - m * p;
  int mu = power(a[p], m), di = inv(a[p]);
  // Normalized series a / (a[p] x^p), truncated to len terms; i + p < n since m >= 1.
  poly c(len);
  for (int i = 0; i < len; ++i) {
    c[i] = mul(a[i + p], di);
  }
  c = logarithm(c);
  for (int i = 0; i < len; ++i) {
    c[i] = mul(c[i], m);
  }
  c = exponent(c);
  for (int i = 0; i < len; ++i) {
    b[i + m * p] = mul(c[i], mu);
  }
  return b;
}

// Power-series square root: returns b with b * b == a (mod x^{a.size()}).
// Uses Newton iteration b <- (b + a / b) / 2, doubling the precision each round.
// The constant term must be a perfect square as an ordinary integer. For a general
// quadratic residue, a modular square root (e.g. Cipolla) is needed instead.
// If a[0] == 0, b cannot be inverted, so only a.size() == 1 is meaningful then.
poly sqrt(const poly &a) {
  if (a.empty()) {
    return poly();
  }
  // Integer square root of a[0], corrected for floating-point rounding.
  int r = (int) std::sqrt((double) a[0]);
  while (r > 0 && (long long) r * r > a[0]) {
    --r;
  }
  while ((long long) (r + 1) * (r + 1) <= a[0]) {
    ++r;
  }
  // Otherwise the iteration below silently produces a wrong series.
  assert((long long) r * r == a[0] && "sqrt: constant term must be a perfect square");
  poly b(1, r);
  while (b.size() < a.size()) {
    poly c(a.begin(), a.begin() + min(a.size(), b.size() << 1));
    b.resize(b.size() << 1);
    c *= inverse(b);  // a / b mod x^{2k} (upper half of b is zero)
    // The lower half of b is already correct, since (b + a/b)/2 equals b there.
    // (mod + 1) / 2 is the inverse of 2.
    for (int i = b.size() >> 1; i < (int) b.size(); ++i) {
      b[i] = mul(c[i], (mod + 1) >> 1);
    }
  }
  b.resize(a.size());
  return b;
}

// Product of all[l..r] by divide and conquer, splitting the index range at its
// midpoint. An empty range (l > r) yields the empty polynomial.
poly multiply_all(int l, int r, vector<poly > &all) {
  if (r < l) {
    return poly();
  }
  if (l == r) {
    return all[l];
  }
  const int mid = (l + r) >> 1;
  const poly left = multiply_all(l, mid, all);
  return left * multiply_all(mid + 1, r, all);
}

// Multipoint evaluation: returns y with y[i] = f(x[i]) mod `mod`. It reduces f
// modulo a subproduct tree of (t - x[i]) from the root down to the leaves.
// Tree layout: node i has children 2i and 2i+1; leaves are n..2n-1.
poly evaluate(const poly &f, const poly &x) {
  int n = x.size();
  if (!n) {
    return poly();
  }
  if (f.empty()) {
    // The zero polynomial is 0 everywhere. Without this, the leaves would stay
    // empty and down[i + n][0] below would read out of bounds.
    return poly(n, 0);
  }
  // up[i]: product of (t - x[j]) over the leaves below node i.
  vector<poly> up(n * 2);
  for (int i = 0; i < n; ++i) {
    up[i + n] = poly{(mod - x[i]) % mod, 1};
  }
  for (int i = n - 1; i; --i) {
    up[i] = up[i << 1] * up[i << 1 | 1];
  }
  // down[i] = f mod up[i]. up[i] divides up[i >> 1], so reducing the parent suffices.
  vector<poly> down(n * 2);
  down[1] = f % up[1];
  for (int i = 2; i < n * 2; ++i) {
    down[i] = down[i >> 1] % up[i];
  }
  // At a leaf, f mod (t - x[i]) is the constant f(x[i]).
  poly y(n);
  for (int i = 0; i < n; ++i) {
    y[i] = down[i + n][0];
  }
  return y;
}

// Lagrange interpolation through the points (x[i], y[i]) using a subproduct tree.
// With P(t) = prod (t - x[i]), the result is sum_i y[i] / P'(x[i]) * P(t) / (t - x[i]).
// The x[i] must be pairwise distinct; otherwise P'(x[i]) == 0 and inv() returns 0.
// Assumes y.size() >= x.size(). Returns the empty polynomial when there are no points.
poly interpolate(const poly &x, const poly &y) {
  int n = x.size();
  if (!n) {
    return poly();  // consistent with evaluate(); previously indexed up[1] of an empty tree
  }
  // up[i]: product of (t - x[j]) over the leaves below node i; leaves are n..2n-1.
  vector<poly> up(n * 2);
  for (int i = 0; i < n; ++i) {
    up[i + n] = poly{(mod - x[i]) % mod, 1};
  }
  for (int i = n - 1; i; --i) {
    up[i] = up[i << 1] * up[i << 1 | 1];
  }
  // a[i] = y[i] / P'(x[i])
  poly a = evaluate(derivative(up[1]), x);
  for (int i = 0; i < n; ++i) {
    a[i] = mul(y[i], inv(a[i]));
  }
  // down[i] = sum over leaves j below i of a[j] * up[i] / (t - x[j]), merged bottom-up.
  vector<poly> down(n * 2);
  for (int i = 0; i < n; ++i) {
    down[i + n] = poly(1, a[i]);
  }
  for (int i = n - 1; i; --i) {
    down[i] = down[i << 1] * up[i << 1 | 1] + down[i << 1 | 1] * up[i << 1];
  }
  return down[1];
}

cogs 2189

只有main函数的部分

int n, k;
poly f;

// Reads n coefficients and k, and applies the fixed chain of series operations.
// Prints the first n - 1 coefficients of the result followed by a final 0.
int main() {
  freopen("polynomial.in","r",stdin);
  freopen("polynomial.out","w",stdout);
  scanf("%d %d", &n, &k);
  f.resize(n);
  for (int i = 0; i < n; ++i) {
    scanf("%d", &f[i]);
  }
  f = inverse(sqrt(f));
  f = exponent(primitive(f));
  f = inverse(f);
  add(f[0], 1);
  f = logarithm(f);
  add(f[0], 1);
  f = derivative(power(f, k));
  for (int i = 0; i + 2 <= n; ++i) {
    printf("%d ", f[i]);
  }
  puts("0");
}

你可能感兴趣的:(多项式,知识点总结,数学,多项式,生成函数)