题目链接:https://ac.nowcoder.com/acm/contest/15332/B
给了两个数组a和b,任意从a中取一个数a[i]和b数组中取一个数b[i],求满足a[i] + b[i]是素数的取法有多少种?
数组长度 <= 1e5.
示例:
a = {1, 2, 2}
b = {3, 4, 6}
满足条件的答案有四种:
{1, 4}, {1, 6}, {2, 3}, {2, 3}
朴素的两重暴力枚举时间复杂度是O(N^2)无法AC本题.
我们考虑统计a数组中每个数的个数,1有1个,2有两个
$$f(a) = x^1 + 2x^2$$
同理
$$f(b) = x^3 + x^4 + x^6$$
我们计算:
$$ f(a) * f(b) = x^4 + \color{red}{3x^5} + 2x^6 + \color{red}{x^7} + 2x^8 $$
实际上我们需要统计的答案就是上述多项式中x的幂是素数的项的系数
我们知道,FFT可以把多项式乘法做到O(NlgN)的,那么就可以通过本题了。
#include
using namespace std;
//大于n*m的最小的2的幂
const int N = 30'0010;
const double PI = acos(-1);
int n, m;
struct Complex {
double x, y;
Complex operator+ (const Complex& t) const {
return {x + t.x, y + t.y};
}
Complex operator- (const Complex& t) const {
return {x - t.x, y - t.y};
}
Complex operator* (const Complex& t) const {
return {x * t.x - y * t.y, x * t.y + y * t.x};
}
}a[N], b[N];
int rev[N], bit, tot;
void fft(Complex a[], int inv)
{
for (int i = 0; i < tot; i ++ )
if (i < rev[i])
swap(a[i], a[rev[i]]);
for (int mid = 1; mid < tot; mid <<= 1)
{
auto w1 = Complex({cos(PI / mid), inv * sin(PI / mid)});
for (int i = 0; i < tot; i += mid * 2)
{
auto wk = Complex({1, 0});
for (int j = 0; j < mid; j ++, wk = wk * w1)
{
auto x = a[i + j], y = wk * a[i + j + mid];
a[i + j] = x + y, a[i + j + mid] = x - y;
}
}
}
}
bool isPrime(int x) {
for (int i = 2; i * i <= x; i++) if (x % i == 0) return 0;
return 1;
}
void solve() {
int _n, _m;
scanf("%d%d", &_n, &_m);
n = m = 0;
for (int i = 0; i < _n; i++) {
int val;
scanf("%d", &val);
a[val].x++;
n = max(n, val);
}
for (int i = 0; i < _m; i++) {
int val;
scanf("%d", &val);
b[val].x++;
m = max(m, val);
}
//下面是FFT的板子
//a[0]*x^0 + a[1]*x^1 + a[2]*x^2 + ... + a[n]*x^n
// for (int i = 0; i <= n; i ++) scanf("%lf", &a[i].x);
//b[0]*x^0 + b[1]*x^1 + b[2]*x^2 + ... + b[m]*x^n
// for (int i = 0; i <= m; i ++) scanf("%lf", &b[i].x);
bit = tot = 0;
while ((1 << bit) < n + m + 1) bit++;
tot = 1 << bit;
for (int i = 0; i < tot; i ++ )
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
fft(a, 1), fft(b, 1);
for (int i = 0; i < tot; i++) a[i] = a[i] * b[i];
fft(a, -1);
long long ans = 0;
for (int i = 2; i <= n + m; i++) {
// printf("%d ", (int)(a[i].x / tot + 0.5));
if (isPrime(i)) ans += (long long)(a[i].x / tot + 0.5);
}
printf("%lld\n", ans);
for (int i = 0; i < tot; i++) a[i].x = a[i].y = b[i].x = b[i].y = 0;
}
int main() {
int T;
scanf("%d", &T);
while (T--) {
solve();
}
}