HDU 5730 (CDQ分治 FFT)

题目链接:点击这里

题意: 给出连续涂 $i$ 个格子 (作为一整段) 的方案数 $a_i$ ($1 \le i \le n$), 求涂满 $n$ 个格子的方案总数, 答案对 313 取模.
这是 FFT 的经典运用: 快速计算形如 $\sum_{i=1}^{n-1} a_i \times b_{n-i}$ 的卷积.
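设 $dp_i$ 表示涂 $i$ 个格子的方案数 (约定 $dp_0 = 1$). 枚举最后一段的长度, 可得 $dp_n=\sum_{i=0}^{n-1} a_{n-i}\times dp_i$. 由于 $dp$ 本身出现在等式右边, 不能直接做一次 FFT, 于是用 CDQ 分治. 处理区间 $[l,r]$ 时, 先递归算出 $[l,mid]$, 再计算 $[l,mid]$ 对 $[mid+1,r]$ 的贡献 $\sum_{j=l}^{mid} dp_j \times a_{i-j}$ ($mid < i \le r$), 最后递归处理 $[mid+1,r]$. 这个贡献正是一个卷积, 可以用 FFT 加速. 朴素的做法是把 $dp_l, dp_{l+1}, \dots, dp_{mid}$ 和 $a_1, a_2, \dots, a_r$ 按原下标放进两个系数数组, 用 FFT 变换成复平面上的点值, 逐点相乘后再逆变换回来, 像这样: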

// Naive CDQ divide and conquer over dp[l..r] (the slow version).
// Precondition: dp[i] for i in [l, r] already holds every contribution
// dp[j] * a[i-j] with 1 <= j < l. On return dp[l..r] are final (mod `mod`).
void solve (int l, int r) {
    if(l == r){
        // Leaf: add the j = 0 term a[l] * dp_0, with dp_0 = 1.
        dp[l] += a[l];
        dp[l] %= mod;
        return;
    }
    int mid = (l+r) >> 1;
    solve (l, mid);
    // The FFT length is chosen from r, not from the interval width r - l,
    // so every interval pays for a transform of size ~r.
    int len = 1;
    while (len <= r+1) {len <<= 1;}
    for (int i = 0; i < len; i++) {
        x1[i] = x2[i] = plex (0, 0);
    }
    // Operands kept at their original indices: x1[j] = dp[j], x2[t] = a[t].
    for (int i = l; i <= mid; i++) x1[i] = plex (dp[i], 0);
    for (int i = 1; i <= r; i++) x2[i] = plex (a[i], 0);
    fft (x1, len, 1);
    fft (x2, len, 1);
    for (int i = 0; i < len; i++) x1[i] = x1[i] * x2[i];
    fft (x1, len, -1);
    for (int i = mid+1; i <= r; i++) {
        // One coefficient can reach (mid-l+1) * 312 * 312, which exceeds
        // INT_MAX for large intervals; converting such a double to int is
        // undefined behavior, so round into long long and reduce first.
        const long long c = (long long) (x1[i].x + 0.5) % mod;
        dp[i] = (int) ((dp[i] + c) % mod);
    }
    solve (mid+1, r);
}

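这个版本超时了. 原因是每个分治区间都做了长度约为 $r$ (而不是 $r-l$) 的 FFT, 数组里绝大部分位置都是无用的 0. 区间个数是 $O(n)$ 的, 所以总复杂度退化为 $O(n^2 \log n)$. 比如计算 $dp_3$ 对 $dp_4$ 的贡献时, 本来只需要算 $dp_3 \times a_1$ 这一项. 所以可以这样优化: $a$ 只存需要用到的下标 $a_1, \dots, a_{r-l}$, $dp$ 的下标整体平移 $l$, 统计 $dp$ 值时再把下标平移回来. 这样每个区间的 FFT 长度只与区间长度有关, 总复杂度降为 $O(n \log^2 n)$.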

#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <vector>
#include <queue>
#include <map>
using namespace std;

// Modulus required by the problem; every dp value is kept in [0, mod).
const int mod = 313;
// pi, computed once instead of re-evaluating acos(-1.0) at every use.
const double pi = acos(-1.0);
// Capacity of the global arrays; must exceed both n and the largest FFT length.
const int maxn = 600005;

// Minimal complex-number type used by fft(): x is the real part, y the
// imaginary part.
struct plex {
    double x;
    double y;

    plex (double re = 0.0, double im = 0.0) : x (re), y (im) {}

    // Component-wise sum.
    plex operator + (const plex &o) const { return plex (x + o.x, y + o.y); }

    // Component-wise difference.
    plex operator - (const plex &o) const { return plex (x - o.x, y - o.y); }

    // Complex product (x + iy) * (o.x + i*o.y).
    plex operator * (const plex &o) const {
        return plex (x * o.x - y * o.y, x * o.y + y * o.x);
    }
};

// Scratch buffers for the two convolution operands; x1 also receives the product.
plex x1[maxn], x2[maxn];

// Reorders y[0..len) into bit-reversed index order, as required by the
// iterative butterfly loop in fft(): afterwards y[i] holds the element that
// was at index rev(i), where rev reverses the low log2(len) bits of i.
// len must be a power of two (len <= 2 is a no-op).
//
// Done in place with O(1) extra memory. A recursive even/odd split that
// allocates two variable-length arrays of `len` elements per call is not
// standard C++, and for len = 2^17 needs ~4 MiB of stack in the first frame.
void change (plex y[], int len) {
    // j always equals rev(i); it is advanced by a "reversed increment".
    for (int i = 1, j = len >> 1; i < len - 1; i++) {
        if (i < j) std::swap (y[i], y[j]);  // swap each pair exactly once
        int k = len >> 1;
        while (j >= k) {  // clear the leading ones, carrying toward the low bits
            j -= k;
            k >>= 1;
        }
        j += k;
    }
}

// In-place iterative radix-2 FFT over y[0..len); len must be a power of two.
//   on ==  1 : forward transform (twiddle angle -2*pi/h per stage).
//   on == -1 : inverse transform; only the real parts are divided by len,
//              because the caller in this file reads back just .x.
void fft (plex y[], int len, int on)
{
    change (y, len);  // bit-reversal permutation before the butterflies
    for (int h = 2; h <= len; h <<= 1) {
        const int half = h >> 1;
        const double ang = -on * 2 * pi / h;
        const plex wn (cos (ang), sin (ang));  // primitive h-th root of unity
        for (int base = 0; base < len; base += h) {
            plex w (1, 0);
            for (int k = 0; k < half; k++) {
                plex &lo = y[base + k];
                plex &hi = y[base + k + half];
                const plex u = lo;
                const plex t = w * hi;
                lo = u + t;
                hi = u - t;
                w = w * wn;
            }
        }
    }
    if (on != -1) return;
    for (int i = 0; i < len; i++) y[i].x /= len;
}

int n;          // number of cells in the current test case
int dp[maxn];   // dp[i]: number of ways to paint i cells, reduced mod `mod`
int a[maxn];    // a[i]: number of ways to paint i consecutive cells as one piece, reduced mod `mod`

// CDQ divide and conquer over dp[l..r].
// Precondition: dp[i] for i in [l, r] already holds every contribution
// dp[j] * a[i-j] with 1 <= j < l. On return dp[l..r] are final (mod `mod`).
void solve (int l, int r) {
    if (l == r) {
        // Leaf: add the j = 0 term a[l] * dp_0, with dp_0 = 1.
        dp[l] += a[l];
        dp[l] %= mod;
        return;
    }
    int mid = (l + r) >> 1;
    solve (l, mid);  // finalize the left half first

    // Contribution of dp[l..mid] to dp[mid+1..r]:
    //   dp[i] += sum_{j=l}^{mid} dp[j] * a[i-j],  where 1 <= i-j <= r-l.
    // Both operands are shifted to start at 0: x1[j-l] = dp[j], x2[t-1] = a[t],
    // so the term for dp[i] lands at index (j-l) + (i-j-1) = i-l-1.
    // With len > r-l+1, cyclic wrap-around only reaches indices < mid-l,
    // which are never read below.
    int len = 1;
    while (len <= r - l + 1) len <<= 1;
    for (int i = 0; i < len; i++) x1[i] = x2[i] = plex (0, 0);
    for (int i = l; i <= mid; i++) x1[i-l] = plex (dp[i], 0);
    for (int i = 1; i <= r-l; i++) x2[i-1] = plex (a[i], 0);
    fft (x1, len, 1);
    fft (x2, len, 1);
    for (int i = 0; i < len; i++) x1[i] = x1[i] * x2[i];
    fft (x1, len, -1);
    for (int i = mid+1; i <= r; i++) {
        // One coefficient can reach (mid-l+1) * 312 * 312, which exceeds
        // INT_MAX for large intervals; converting such a double to int is
        // undefined behavior, so round into long long and reduce first.
        const long long c = (long long) (x1[i-l-1].x + 0.5) % mod;
        dp[i] = (int) ((dp[i] + c) % mod);
    }
    solve (mid+1, r);
}

int main () {
    // Process test cases until EOF or a terminating n == 0.
    for (;;) {
        if (scanf ("%d", &n) != 1 || n == 0) break;
        // Read a[1..n] reduced mod `mod`, and clear dp for this case.
        for (int i = 1; i <= n; ++i) {
            scanf ("%d", &a[i]);
            a[i] %= mod;
            dp[i] = 0;
        }
        dp[0] = 0;  // solve(1, n) only touches dp[1..n]; reset kept for tidiness
        solve (1, n);
        printf ("%d\n", dp[n]);
    }
    return 0;
}

你可能感兴趣的:(FFT,&&,NTT,分治)