题目:
其中:x1, x2, …,xn是未知数,k1,k2,…,kn是系数,p1,p2,…pn是指数。且方程中的所有数均为整数。
假设未知数1≤ xi ≤M, i=1,,,n,求这个方程的整数解的个数。
第1行包含一个整数n。
第2行包含一个整数M。
第3行到第n+2行,每行包含两个整数,分别表示ki和pi。两个整数之间用一个空格隔开。第3行的数据对应i=1,第n+2行的数据对应i=n。
仅一行,包含一个整数,表示方程的整数解的个数
直接枚举6个x复杂度为M^6,肯定超时。需要一个技巧,对于这种方程,可以分组DFS,使用中间相遇(Meet in the middle)。
e.g. a+b+c+d+e+f = 0 分组为 (a+b+c) + (d+e+f) = 0,对两组分别进行DFS,将两组每一次的DFS得到的所有结果分别保存在res1,res2两个数组里。排序后,若res1[i] + res2[j] = 0 ,则对应原方程的一组解。若res1[l1...r1]中每个数相等,res2[l2...r2]中每个数相等,且res1[l1]+res2[l2]=0,根据乘法原理,对应了原方程的(r1-l1+1)*(r2-l2+1)组解。
暴力枚举为O(M^n), 中间相遇为O( M^(n/2下取整) )+O( M^(n/2上取整) ) = O(M^(n/2))。可将复杂度降为原来的根号级别。
中间相遇之所以能实质上比纯暴力DFS耗时更少,是因为暴力DFS实质上枚举了每一个具体的解;而中间相遇枚举了所有可行解来自的方案,并利用数学上的乘法原理来计算具体解的个数。也就是说,中间相遇法用乘法代替了暴力DFS的自增一,当然更快。
中间相遇的最初模型是图论中求起点到终点路径长为定值的方案数,从起点终点出发分别走一半的路程后用乘法原理求方案数。但是对于这个问题,还有一个更经典更高效的解法:矩阵快速幂。这道题可以用中间相遇法,说明它可以图论建模,但由于转为图论模型后扩展出的节点数太多,不适合用矩阵表示。如果对于路径长度(即本题中的未知数个数)较长,但扩展出的节点较少的问题,我认为矩阵乘法不失为一种优秀的解法。
贴上我的AC代码:
#include<cstdio> #include<cstring> #include<algorithm> using namespace std; #define MAXN 3500000 //150^3 int n, m, ans; int p[10], k[10]; int res1[MAXN], rk1; int res2[MAXN], rk2; int p1[5], k1[5]; int ksm(int d, int z) //快速幂,稍微加快一点 { int res = 1; for (; z>0 ;d*=d, z>>=1) if (z & 1) res = res * d; return res; } int lv; void dfs(int i, int res, int*a, int&rk) //最后两个参数的设置可以让两组DFS共享一个函数 { if (i > lv) { a[++rk] = res; return; } for (int x = 1; x<=m; ++x) dfs(i+1, res + k1[i]*ksm(x,p1[i]), a, rk); } void work() { int i, j = rk2; int cnt1, cnt2; sort(res1+1, res1+rk1+1); sort(res2+1, res2+rk2+1); //这个循环中根据单调性,i不降,j不增,所以总的时间复杂度为O(n) for (i = 1; i<=rk1; ++i) { while (res1[i]+res2[j] > 0 && j>0) --j; //当i, j越小,res1[i]+res2[j]越小。i增加后,j的取值应单调减小。 if (j<=0) break; if (res1[i] + res2[j] != 0) continue; cnt1 = cnt2 = 1; while (res1[i+1] == res1[i] && i<rk1) ++i, ++cnt1; while (res2[j-1] == res2[j] && j>1) --j, ++cnt2; ans = ans + cnt1 * cnt2; // 乘法原理 } } int main() { scanf("%d%d", &n, &m); int i; int part1=n/2, part2=(n+1)/2; //分别为下取整,上取整。 for (i = 1; i<=n; ++i) scanf("%d%d", k+i, p+i); for (i = 1; i<=part1; ++i) p1[i] = p[i], k1[i] = k[i]; lv = part1; dfs(1, 0, res1, rk1); for (i = 1; i<=part2; ++i) p1[i] = p[i+part1], k1[i] = k[i+part1]; lv = part2; dfs(1, 0, res2, rk2); work(); printf("%d\n", ans); return 0; }