题意: 已知把连续 $i$ 个格子作为一段涂色的方案数为 $a_i$ ($1 \le i \le n$), 求把 $n$ 个格子全部涂色的方案总数 (答案对 313 取模).
FFT 的经典运用: 计算形如 $\sum_{i=1}^{n-1} a_i \times b_{n-i}$ 的卷积.
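设 $dp_i$ 表示涂满 $i$ 个格子的方案数, 显然 $dp_n = a_n + \sum_{i=1}^{n-1} dp_i \times a_{n-i}$ (等价于令 $dp_0 = 1$ 后的 $dp_n = \sum_{i=0}^{n-1} dp_i \times a_{n-i}$). 然后就可以用 CDQ 分治: 在处理区间 $[l, r]$ 的时候, 关键是计算 $[l, mid]$ 对 $[mid+1, r]$ 的贡献, 这个像卷积一样的东西就可以用 FFT 加速. 按照朴素的做法, 就是把 $dp_l, dp_{l+1}, \dots, dp_{mid}$ 和 $a_1, a_2, \dots, a_r$ 作为复平面上的点 (复数数组) 存下来, 做 FFT、逐点相乘、再逆变换取回卷积结果, 像这样: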
// Unoptimised CDQ step, shown for comparison (this version TLEs): the FFT length
// is chosen from the absolute index r rather than from the interval length, so
// every call transforms about r points no matter how small [l, r] is.
void solve (int l, int r) {
    if (l == r) {
        // Leaf: all contributions from dp[1..l-1] have already been added;
        // a[l] accounts for painting the whole prefix as one segment.
        dp[l] += a[l];
        dp[l] %= mod;
        return;
    }
    int mid = (l + r) >> 1;
    solve (l, mid);

    // Convolve dp[l..mid] (stored at their absolute indices) with a[1..r].
    int len = 1;
    while (len <= r + 1) len <<= 1;
    for (int i = 0; i < len; i++) x1[i] = x2[i] = plex (0, 0);
    for (int i = l; i <= mid; i++) x1[i] = plex (dp[i], 0);
    for (int i = 1; i <= r; i++) x2[i] = plex (a[i], 0);
    fft (x1, len, 1);
    fft (x2, len, 1);
    for (int i = 0; i < len; i++) x1[i] = x1[i] * x2[i];
    fft (x1, len, -1);

    for (int i = mid + 1; i <= r; i++) {
        // x1[i].x is sum over j in [l, mid] of dp[j] * a[i - j]; every factor is
        // already reduced mod `mod`, so the sum can reach (mid-l+1) * 312 * 312,
        // which exceeds INT_MAX for large intervals. Converting such a double to
        // int is undefined behaviour, so round into a long long and reduce first.
        long long sum = (long long) (x1[i].x + 0.5);
        dp[i] = (int) ((dp[i] + sum % mod) % mod);
    }
    solve (mid + 1, r);
}
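这个版本结果超时了: 每次 FFT 的长度是按绝对下标 $r$ (而不是区间长度 $r-l+1$) 来取的, 所以即使区间很小, 分治递归的每一层也都要对约 $r$ 个点做 FFT, 出现了大量无用的计算和内存消耗, 相当于一个超级大的常数 (实际上总复杂度退化到了 $O(n^2 \log n)$). 比如计算 $dp_3$ 对 $dp_4$ 的贡献的时候, 正常来算只需要计算 $dp_3 \times a_1$ 就行了. 所以可以做一个优化: 把 $dp_l, \dots, dp_{mid}$ 平移到从下标 $0$ 开始存, $a$ 只存从需要用到的最小下标到最大下标 (即 $a_1, \dots, a_{r-l}$), 然后统计 dp 值的时候再把下标变回来 ($dp_i$ 取卷积结果的第 $i-l-1$ 项) 就好了.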
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
// Modulus applied to every count (a[i] and dp[i] are kept reduced modulo it).
const int mod = 313;
// pi, computed once; the former macro re-evaluated acos(-1.0) at every use.
const double pi = acos(-1.0);
// Capacity of the global dp/a arrays and the FFT scratch buffers.
const int maxn = 600005;
// Minimal complex number for the FFT: x is the real part, y the imaginary part.
struct plex {
    double x, y;
    // Both parts default to 0, so plex() is zero and plex(v) is the real value v.
    plex (double re = 0.0, double im = 0.0) : x (re), y (im) {}
    plex operator + (const plex &rhs) const { return plex (x + rhs.x, y + rhs.y); }
    plex operator - (const plex &rhs) const { return plex (x - rhs.x, y - rhs.y); }
    // (x + iy)(u + iv) = (xu - yv) + i(xv + yu)
    plex operator * (const plex &rhs) const {
        return plex (x * rhs.x - y * rhs.y, x * rhs.y + y * rhs.x);
    }
};
// Scratch buffers for the convolutions performed in solve().
plex x1[maxn], x2[maxn];
// Reorders y[0..len-1] into bit-reversed index order, the input order expected
// by the iterative radix-2 butterflies in fft(). len must be a power of two.
//
// Done in place with swaps. The previous recursive version copied the data into
// two variable-length stack arrays at every level: VLAs are a non-standard GCC
// extension in C++, and for len = 2^17 the first frame alone needed 4 MiB of
// stack, which risks a stack overflow (plus O(len log len) element copies).
void change (plex y[], int len) {
    // j holds the bit-reversal of i. Stepping i by one is a "reversed increment"
    // of j: clear the run of leading 1-bits from the top, then set the next 0-bit.
    for (int i = 1, j = 0; i < len; i++) {
        int bit = len >> 1;
        for (; j & bit; bit >>= 1)
            j ^= bit;
        j ^= bit;
        // Swap each pair only once (when i < j); fixed points stay where they are.
        if (i < j)
            std::swap (y[i], y[j]);
    }
}
// In-place iterative radix-2 FFT over y[0..len-1]; len must be a power of two.
// on == 1: forward transform using the root exp(-2*pi*i/h) at each stage.
// on == -1: inverse transform; afterwards only the real parts are divided by
// len (the imaginary parts are left unscaled; solve() reads only .x).
// Twiddles are advanced by repeated multiplication (w = w * wn), so their
// rounding error grows with the stage size h.
void fft(plex y[], int len, int on)
{
    change(y, len);  // bit-reversal permutation before the butterflies
    for (int h = 2; h <= len; h <<= 1) {
        const int half = h / 2;
        const double theta = -on * 2 * pi / h;
        const plex wn(cos(theta), sin(theta));  // primitive h-th root of unity
        for (int start = 0; start < len; start += h) {
            plex w(1, 0);
            for (int k = start; k < start + half; k++) {
                const plex even = y[k];
                const plex odd = w * y[k + half];
                y[k] = even + odd;
                y[k + half] = even - odd;
                w = w * wn;
            }
        }
    }
    if (on == -1) {
        for (int i = 0; i < len; i++)
            y[i].x /= len;
    }
}
// Number of cells in the current test case.
int n;
// dp[i]: number of ways (modulo mod) to paint i cells; filled in by solve().
int dp[maxn];
// a[i]: number of ways (reduced modulo mod) to paint i consecutive cells as one segment.
int a[maxn];
// CDQ divide and conquer over dp[l..r] for the recurrence
//   dp[i] = a[i] + sum_{j=1}^{i-1} dp[j] * a[i - j]   (mod `mod`).
// Solves the left half first, adds its contribution to the right half with one
// FFT convolution, then recurses into the right half. When the leaf for index i
// is reached, all contributions from dp[1..i-1] have already been added.
void solve (int l, int r) {
    if (l == r) {
        // a[l]: painting all l cells as a single segment (the dp[0] = 1 term).
        dp[l] += a[l];
        dp[l] %= mod;
        return;
    }
    int mid = (l + r) >> 1;
    solve (l, mid);

    // Shifted operands: x1[j - l] = dp[j] for j in [l, mid], and
    // x2[k - 1] = a[k] for k in [1, r - l]. The term dp[j] * a[i - j] then lands
    // at index (j - l) + (i - j - 1) = i - l - 1, independent of j.
    // len > r - l + 1 suffices: any cyclic wrap-around of the product ends up
    // below index mid - l, which is never read.
    int len = 1;
    while (len <= r - l + 1) len <<= 1;
    for (int i = 0; i < len; i++) x1[i] = x2[i] = plex (0, 0);
    for (int i = l; i <= mid; i++) x1[i - l] = plex (dp[i], 0);
    for (int i = 1; i <= r - l; i++) x2[i - 1] = plex (a[i], 0);
    fft (x1, len, 1);
    fft (x2, len, 1);
    for (int i = 0; i < len; i++) x1[i] = x1[i] * x2[i];
    fft (x1, len, -1);

    for (int i = mid + 1; i <= r; i++) {
        // Each factor is already reduced mod `mod`, so the raw sum can reach
        // (mid-l+1) * 312 * 312, which exceeds INT_MAX for large intervals.
        // Converting such a double to int is undefined behaviour, so round into
        // a long long and reduce before adding.
        long long sum = (long long) (x1[i - l - 1].x + 0.5);
        dp[i] = (int) ((dp[i] + sum % mod) % mod);
    }
    solve (mid + 1, r);
}
// Reads test cases (n followed by a[1..n]) until EOF or a case with n == 0,
// and prints dp[n] for each.
int main () {
    for (;;) {
        if (scanf ("%d", &n) != 1 || n == 0)
            break;
        for (int i = 1; i <= n; ++i) {
            dp[i] = 0;          // clear values left over from the previous case
            scanf ("%d", &a[i]);
            a[i] %= mod;        // only a[i] modulo mod affects the answer
        }
        dp[0] = 0;
        solve (1, n);
        printf ("%d\n", dp[n]);
    }
    return 0;
}