给定a1, a2, a3, a4, a5,求符合条件的解(x1, x2, x3, x4, x5)的个数,使得a1*x1^3 + a2*x2^3 + a3*x3^3 + a4*x4^3 + a5*x5^3 = 0。其中-50 <= x1, x2, x3, x4, x5 <= 50且x1-x5都不为0。
x1-x5各有100个可能值,相互的解法是使用5重循环逐一判断,但这样用的时间太长了。因此可以把它们分为两部分(x1, x2)和(x3, x4, x5),求出所有的sum = a1*x1^3 + a2*x2^3,将非负的sum值和负的sum值映射到非负数和负数两个哈希表上,把sum相同的压缩到一个结点,只记录相同的sum的个数。然后求出所有的sum = a3*x3^3 + a4*x4^3 + a5*x5^3,然后分别到非负哈希表和负数哈希表上查找对应的(x1, x2)组合的个数。把所有的结果累加起来就得到最终的答案了。
#include <iostream> #include <cstdio> #include <algorithm> using namespace std; const int N = 10005; const int H = 9997; struct Node { long sum; int cnt; int next; }; Node nodeP[N], nodeN[N]; int curP, curN; int hashTableP[H], hashTableN[H]; int a1, a2, a3, a4, a5; int cube[101]; long long ans; void initHash() { curP = curN = 0; ans = 0; for (int i = 0; i < H; ++i) hashTableP[i] = hashTableN[i] = -1; for (int i = 0; i <= 100; ++i) { int t = i - 50; cube[i] = t * t * t; } } void insertHash(int i, int j) { long sum = cube[i] * a1 + cube[j] * a2; int h = sum % H; int next; if (h >= 0) { next = hashTableP[h]; while (next != -1) { if (nodeP[next].sum == sum) { ++nodeP[next].cnt; return; } next = nodeP[next].next; } nodeP[curP].cnt = 1; nodeP[curP].sum = sum; nodeP[curP].next = hashTableP[h]; hashTableP[h] = curP; ++curP; } else { h = -h; next = hashTableN[h]; while (next != -1) { if (nodeN[next].sum == sum) { ++nodeN[next].cnt; return; } next = nodeN[next].next; } nodeN[curN].cnt = 1; nodeN[curN].sum = sum; nodeN[curN].next = hashTableN[h]; hashTableN[h] = curN; ++curN; } } void getAns(int i, int j, int k) { long sum = cube[i] * a3 + cube[j] * a4 + cube[k] * a5; int h = sum % H; int next; if (h > 0) { next = hashTableN[h]; while (next != -1) { if (nodeN[next].sum + sum == 0) { ans += nodeN[next].cnt; return; } next = nodeN[next].next; } } else { h = -h; next = hashTableP[h]; while (next != -1) { if (nodeP[next].sum + sum == 0) { ans += nodeP[next].cnt; return; } next = nodeP[next].next; } } } int main() { initHash(); scanf("%d%d%d%d%d", &a1, &a2, &a3, &a4, &a5); for (int i = 0; i <= 100; ++i) { if (i == 50) continue; for (int j = 0; j <= 100; ++j) { if (j == 50) continue; insertHash(i, j); } } for (int i = 0; i <= 100; ++i) { if (i == 50) continue; for (int j = 0; j <= 100; ++j) { if (j == 50) continue; for (int k = 0; k <= 100; ++k) { if (k == 50) continue; getAns(i, j, k); } } } printf("%lld\n", ans); return 0; }