最近在看几道整体二分还有cdq分治的东西,突然间想起前几个礼拜的ZOJ题,然后看了一下代码,经过了一些深思熟虑之后,发现自己终于看懂了,下面就用别人的代码来剖析一下整个解题的思路吧,具体的内容我再看看。
首先要解决这个问题需要有一些卷积的知识,或者说是多项式乘法,一个很典型的多项式乘法的东东就是FFT,然后原来在数论意义下(即mod P)的意义下,也有相应的NTT(快速数论变换),思想是和FFT一致的,不过在这里原根稍微不一样,而且也不用去管复数,当然我是不会懂的,下面抄个链接:
NTT(快速数论变换) http://blog.csdn.net/zz_1215/article/details/40430041
现在假设自己懂了NTT,然后我们把它当作黑盒,然后就把它当作可以实现mod P意义下卷积的一个工具,然后去理解一下题目的做法。
经过一些理论推导,我们可以发现,实际上我们要求的东西是 dp[n]=n!-(dp[1]*(n-1)!+dp[2]*(n-2)!+... dp[n-1]*1!).
实际上dp[n]=n!- dp[i]和i!的卷积的第n项。这样的一个算法暴力做的话要算到n的话是O(n^2)的,下面看下cdq分治。
个人对cdq分治的理解是这样的:
T(n)=2T(n/2)+O(f(n)) 一个传统的典型的分治算法里,O(f(n))是指的将两个子问题合并的代价,非常典型的就是归并排序。而在cdq分治里,O(f(n))就不一定是合并的代价了,在归并排序里,左子问题对右子问题是没有影响的,而现实的分治里,可能会出现左子问题对右子问题产生影响的情况,及前面的操作是直接对右边操作产生影响的,我们必须先做了左子问题,然后把左子问题的影响加到右子问题,然后才能再递归右子问题。下面的框架给了两种分治的思想吧。
// traditional divide-and-conquer void solve(l,r) { int mid=(l+r)>>1; solve(l,mid); solve(mid+1,r); combine 2 subproblem. } // cdq divide-and conquer void solve1(l,r) { int mid=(l+r)>>1; solve(l,mid); add the affect of [l,mid] to [mid+1,r] solve(mid+1,r); } void solve2(l,r) { int mid=(l+r)>>1; add the affect of [l,mid] to [mid+1,r]; solve(l,mid); solve(mid+1,r); }
下面我们来看看别人的代码里是怎么做的,下面的代码抄自下面的链接,注释是自己给别人加上的,方便一下理解吧。
ZOJ3874:http://acm.hust.edu.cn/vjudge/problem/viewSource.action?id=3709562
// dp[l]要求的 // f[l] 是l的阶乘 void solve (int l, int r) { // 递归边界,l==r时,说明所有比l小的卷积都算完了,所以dp[l]=f[l]-dp[l]; if (l == r) { dp[l] = (f[l] - dp[l] + P) % P; return; } int m = (l + r) >> 1; // 递归左子问题,现在要做的是算出dp[l...mid]和f[]的卷积加到dp[mid+1...r]上 solve (l, m); // 下面的部分就是将dp[l...mid]赋给a,将f的值赋给b,然后做NTT,然后算完之后再逆变换回来 // 做完逆变换后,a[x]存的就是dp和f的卷积的第x项, int s = 1, n = m - l + 1; while (s <= n * 2) s <<= 1; a[0] = b[0] = 0; for (int i = 1, j = l; i < s; i++, j++) a[i] = (j <= m ? dp[j] : 0); for (int i = 1; i < s; i++) b[i] = f[i]; NTT (a, s); NTT (b, s); for (int i = 0; i < s; i++) a[i] = a[i] * b[i] % P; NTT (a, s, true); // end // 将影响加到dp[mid+1...r]上。 for (int i = m + 1, j = m - l + 2; i <= r; i++, j++) (dp[i] += a[j]) %= P; solve (m + 1, r); }
最后不难发现将左子问题的影响加到右子问题上其实就是一个O(nlogn)的过程。
T(n)=2T(n/2)+O(nlogn)
所以最后出来的复杂度应该是O(nlog^2n)的。