Problem
Solution
Derive the closed form of f: $f_n = a \times 3^n - b \times (-1)^n$
In the end, we need to compute:
$\sum_{s' \subseteq s,\ |s'| = k} \prod_{x \in s'} w^{x}$
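This can be viewed as a generating function and solved with divide-and-conquer FFT (which guarantees that no number is chosen twice).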
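Written out, the generating function is a product of one linear factor per element, which is why each element is used at most once. Here z is an auxiliary formal variable introduced only for illustration:

$$\sum_{s' \subseteq s,\ |s'| = k}\ \prod_{x \in s'} w^{x} \;=\; [z^{k}] \prod_{x \in s} \left(1 + w^{x} z\right)$$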
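Let g(n) denote the total contribution of choosing n numbers. The g arrays of two disjoint ranges can be merged by convolution.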
Summary:
Note that the square of the modulus overflows int.
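I used FFT rather than NTT: 99991 - 1 = 2 × 49995 contains only a single factor of 2, so this modulus does not support NTT at power-of-two lengths.
I had declared Pi as an int, and I also put a resize in the wrong place. I spent more than an hour debugging, convinced the whole time that the FFT itself was wrong.
At first the precomputed roots of unity were simply wrong, and I never noticed that Pi was an int.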
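Every detail deserves attention. Parts that are already confirmed to be correct should be set aside for the moment, and then every remaining detail of the code has to be checked one by one!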
#include <bits/stdc++.h>
using namespace std;
#define rep(i,l,r) for(int i = l ; i <= r ; i++)
#define repd(i,r,l) for(int i = r ; i >= l ; i--)
#define rvc(i,S) for(int i = 0 ; i < (int)S.size() ; i++)
#define rvcd(i,S) for(int i = ((int)S.size()) - 1 ; i >= 0 ; i--)
#define fore(i,x) for (int i = head[x] ; i ; i = e[i].next)
#define forup(i,l,r) for (int i = l ; i <= r ; i += lowbit(i))
#define fordown(i,id) for (int i = id ; i ; i -= lowbit(i))
#define pb push_back
#define prev prev_
#define stack stack_
#define mp make_pair
#define fi first
#define se second
#define lowbit(x) ((x)&(-(x)))
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
typedef pair<int,int> pr;
const int maxn = 200020;
const double Pi = acos(-1.0);
typedef complex<double> Complex;
typedef vector <int> poly;
const int mod = 99991;
//====================================basic operation===============================
inline void add(int &x, int y) {
    x += y;
    if (x >= mod) {
        x -= mod;
    }
}
inline void sub(int &x, int y) {
    x -= y;
    if (x < 0) {
        x += mod;
    }
}
inline int mul(int x, int y) {
    return (int) ((long long) x * y % mod);
}
inline int power(int x, int y) {
    int res = 1;
    while (y) {
        if (y & 1) {
            res = mul(res, x);
        }
        x = mul(x, x);
        y >>= 1;
    }
    return res;
}
inline int inv(int a) {
    int b = mod, u = 0, v = 1;
    while (a) {
        int t = b / a;
        b -= t * a;
        swap(a, b);
        u -= t * v;
        swap(u, v);
    }
    if (u < 0) {
        u += mod;
    }
    return u;
}
//=======================================================================================
namespace fft {
    int base = 1;
    poly rev = {0, 1};
    vector<Complex> roots = {Complex(1,0),Complex(1,0)};
    void ensure_base(int nbase) { // all the preprocessing that dft needs
        if (nbase <= base) {
            return;
        }
        rev.resize(1 << nbase);
        for (int i = 0; i < 1 << nbase; ++i) { // precompute the bit-reversal table
            rev[i] = rev[i >> 1] >> 1 | (i & 1) << (nbase - 1);
        }
        roots.resize(1 << nbase);
        while (base < nbase) { // precompute the roots of unity
            Complex z(cos(Pi / (1 << base)),sin(Pi / (1 << base)));
            for (int i = 1 << (base - 1); i < 1 << base; ++i) { // handle the (1 << (base + 1))-th roots of unity
                roots[i << 1] = roots[i];
                roots[i << 1 | 1] = roots[i] * z;
            }
            ++base;
        }
    }
    void dft(vector<Complex> &a) {
        int n = a.size(), zeros = __builtin_ctz(n);
        ensure_base(zeros);
        int shift = base - zeros;
        for (int i = 0; i < n; ++i) {
            if (i < rev[i] >> shift) {
                swap(a[i], a[rev[i] >> shift]);
            }
        }
        for (int i = 1; i < n; i <<= 1) {
            for (int j = 0; j < n; j += i << 1) {
                for (int k = 0; k < i; ++k) {
                    Complex x = a[j + k], y = a[j + k + i] * roots[i + k];
                    a[j + k] = x + y;
                    a[j + k + i] = x - y;
                }
            }
        }
    }
    vector <Complex> tmpa,tmpb;
    poly multiply(poly a, poly b) { // convert to complex first, then compute; mind the modulus and where the resize goes
        int need = a.size() + b.size() - 1, nbase = 0;
        while (1 << nbase < need) {
            ++nbase;
        }
        ensure_base(nbase);
        int sz = 1 << nbase;
        tmpa.resize(sz) , tmpb.resize(sz);
        fill(tmpa.begin(),tmpa.end(),Complex(0,0)) , fill(tmpb.begin(),tmpb.end(),Complex(0,0));
        rep(i,0,a.size() - 1) tmpa[i] = Complex(a[i],0);
        rep(i,0,b.size() - 1) tmpb[i] = Complex(b[i],0);
        dft(tmpa) , dft(tmpb);
        rep(i,0,sz - 1) tmpa[i] *= tmpb[i];
        dft(tmpa); // inverse transform: forward dft followed by the reverse below
        reverse(tmpa.begin() + 1,tmpa.end());
        a.resize(need);
        rep(i,0,need - 1) a[i] = (ll)(tmpa[i].real() / sz + 0.5) % mod;
        return a;
    }
}
using fft::multiply;
poly& operator *= (poly &a, const poly &b) {
    if ((int) min(a.size(), b.size()) < 128) {
        poly c = a;
        a.assign(a.size() + b.size() - 1, 0);
        for (int i = 0; i < (int) c.size(); ++i) {
            for (int j = 0; j < (int) b.size(); ++j) {
                add(a[i + j], mul(c[i], b[j]));
            }
        }
    } else {
        a = multiply(a, b);
    }
    return a;
}
poly operator * (const poly &a, const poly &b) {
    poly c = a;
    return c *= b;
}
int n,k;
int a[maxn],A,B;
poly f;
poly solve(int l,int r,int w){
    if ( l == r ) return poly({1,power(w,a[l])});
    int mid = (l + r) >> 1;
    return solve(l,mid,w) * solve(mid + 1,r,w);
}
int main(){
    //freopen("input.txt","r",stdin);
    scanf("%d %d",&n,&k);
    rep(i,1,n) scanf("%d",&a[i]);
    int f0,f1;
    scanf("%d %d",&f0,&f1);
    A = (ll)(f0 + f1) * power(4,mod - 2) % mod;
    B = (f0 - A + mod) % mod;
    int ans = 0;
    f = solve(1,n,3);
    ans = mul(f[k],A);
    f = solve(1,n,mod - 1);
    add(ans,mul(f[k],B));
    cout<<ans<<endl;
}