4836: [Lydsy2017年4月月赛]二元运算

4836: [Lydsy2017年4月月赛]二元运算

Time Limit: 8 Sec Memory Limit: 128 MB
Submit: 286 Solved: 92
[Submit][Status][Discuss]
Description

定义二元运算 opt 满足
这里写图片描述

现在给定一个长为 n 的数列 a 和一个长为 m 的数列 b ,接下来有 q 次询问。每次询问给定一个数字 c
你需要求出有多少对 (i, j) 使得 a_i opt b_j=c 。
Input

第一行是一个整数 T (1≤T≤10) ,表示测试数据的组数。
对于每组测试数据:
第一行是三个整数 n,m,q (1≤n,m,q≤50000) 。
第二行是 n 个整数,表示 a_1,a_2,?,a_n (0≤a_1,a_2,?,a_n≤50000) 。
第三行是 m 个整数,表示 b_1,b_2,?,b_m (0≤b_1,b_2,?,b_m≤50000) 。
第四行是 q 个整数,第 i 个整数 c_i (0≤c_i≤100000) 表示第 i 次查询的数。
Output

对于每次查询,输出一行,包含一个整数,表示满足条件的 (i, j) 对的个数。

Sample Input

2

2 1 5

1 3

2

1 2 3 4 5

2 2 5

1 3

2 4

1 2 3 4 5
Sample Output

1

0

1

0

0

1

0

1

0

1
HINT

Source

鸣谢Tangjz提供试题

[Submit][Status][Discuss]

先考虑 aibj 的情形
考虑到每个数字不超过 50000 ,记 c=50000,bi=cbi
那么, ai+bj=c+aibj
定义 Ai 为数组 a 中数字 i 出现的次数
然后让 A B 做卷积,其中第 i 项系数就是差值为 ic 的数量了
所以有用的系数都是大于等于 50000 的项
这一部分的复杂度为 O(klogk)
不过对于 ai<bj 的情形就不能这样了
考虑对值域分治然后做卷积
每次在 A 中填入 [l,mid] B 中填入 [mid+1,r]
这样做卷积就能避开所有非法情形了
这一部分复杂度 O(klog2k)
很科学对不对,很科学对不对???

4836: [Lydsy2017年4月月赛]二元运算_第1张图片

我也不知道为什么。。。。反正真的卡了好久好久的常数
以及一开始的几个 TLE ,其实答案也是错的。。不过根本跑不完
首先,这道题卡什么都卡得丧心病狂
模数用 1004535809 是根本不够的
于是百度找到了这个 3221225473 ,原根是 5
存储得用 unsigned int 平方得用 unsigned long long
然后开始解决漫长的卡常数问题。。。

  1. 千万不要在 NTT 过程中写很多除法, klog2k 次不是开玩笑的
  2. 传指针可能有些慢?所以分开来写两种 NTT
  3. 随意试了一下发现,让初始值域为 [0,65535] 使得每一层值域大小都为 2k 能够显著提升效率
  4. 用三目运算符配合加减法加速加法和减法的取模

于是乎终于卡过了。。。(不过还是非常非常慢的)

#include
#include
#include
#include
#include
using namespace std;

const int maxn = (1 << 17);
typedef long long LL;
typedef unsigned int u32;
typedef unsigned long long u64;
const u32 W = 1;
const u32 TT = 10;
const u64 WW = 1;
const u32 mo = 3221225473LL;

int n,m,q,T,A[maxn],B[maxn],ca[maxn],cb[maxn],f[maxn],g[maxn],LOG[maxn + 1];
u32 Ans[maxn],w[maxn + 1],_w[maxn + 1],Inv[maxn + 233],a[maxn],b[maxn],c[maxn];

#define Mul(x,y) (WW * (x) * (y) % mo)
#define max(a,b) ((a) > (b) ? (a) : (b))
#define swap(x,y) ((x) ^= (y),(y) ^= (x),(x) ^= (y))
#define Dec(x,y) ((x) < (y) ? (x) + mo - (y) : (x) - (y))
#define Add(x,y) (mo - (x) > (y) ? (x) + (y) : (x) - mo + (y))

inline u32 ksm(u32 x,u32 y)
{
    u32 ret = W;
    for (; y; y >>= W)
    {
        if (y & W) ret = Mul(ret,x);
        x = Mul(x,x);
    }
    return ret;
}

inline int getint()
{
    char ch = getchar(); int ret = 0;
    while (ch < '0' || '9' < ch) ch = getchar();
    while ('0' <= ch && ch <= '9')
        ret = ret * 10 + ch - '0',ch = getchar();
    return ret;
}

char s[20];
inline void Print(u32 x)
{
    if (!x) {puts("0"); return;} int len = 0;
    while (x) s[++len] = x % TT,x /= TT;
    for (int i = len; i; i--) putchar(s[i] + '0'); puts("");
}

inline void Rader(u32 *F,int N)
{
    int j = (N >> 1);
    for (int i = 1; i < N - 1; i++)
    {
        if (i < j) swap(F[i],F[j]); int k = (N >> 1);
        while (j >= k) j -= k,k >>= 1; j += k;
    }
}

inline void NTT(u32 *F,int N)
{
    Rader(F,N);
    for (int k = 2; k <= N; k <<= 1)
    {
        int tmp = k >> 1,G = maxn / k;
        for (int i = 0; i < N; i += k)
        {
            int now = 0;
            for (int j = i; j < i + tmp; j++)
            {
                u32 u = F[j],v = Mul(w[now],F[j + tmp]);
                F[j] = Add(u,v); F[j + tmp] = Dec(u,v); now += G;
            }
        }
    }
}

inline void _NTT(u32 *F,int N)
{
    Rader(F,N);
    for (int k = 2; k <= N; k <<= 1)
    {
        int tmp = k >> 1,G = maxn / k;
        for (int i = 0; i < N; i += k)
        {
            int now = 0;
            for (int j = i; j < i + tmp; j++)
            {
                u32 u = F[j],v = Mul(_w[now],F[j + tmp]);
                F[j] = Add(u,v); F[j + tmp] = Dec(u,v); now += G;
            }
        }
    }
    for (int i = 0; i < N; i++) F[i] = Mul(F[i],Inv[N]);
}

inline void Work()
{
    int N = 131072,t = 65536;
    for (int i = 1; i <= n; i++) ++a[A[i]];
    for (int i = 1; i <= m; i++) ++b[t - B[i]];
    NTT(a,N); NTT(b,N);
    for (int i = 0; i < N; i++) c[i] = Mul(a[i],b[i]); _NTT(c,N);
    for (int i = t; i < N; i++) Ans[i - t] += c[i];
    for (int i = 0; i < N; i++) a[i] = b[i] = c[i] = 0;
}

inline void Calc(int l,int r)
{
    if (l == r) return;
    int mid = l + r >> 1,M = r - l + 1 << 1,tf = 0,tg = 0;
    for (int i = l; i <= mid; i++)
        if (ca[i]) a[i - l] = ca[i],f[++tf] = i;
    for (int i = mid + 1; i <= r; i++)
        if (cb[i]) b[i - l] = cb[i],g[++tg] = i;
    if (1LL * tf * tg > 1LL * M * LOG[M])
    {
        NTT(a,M); NTT(b,M);
        for (int i = 0; i < M; i++) c[i] = Mul(a[i],b[i]); _NTT(c,M);
        for (int i = 0; i < M; i++) Ans[i + l * 2] += c[i],a[i] = b[i] = c[i] = 0;
    }
    else
    {
        for (int i = 1; i <= tf; i++)
            for (int j = 1; j <= tg; j++)
                Ans[f[i] + g[j]] += W * ca[f[i]] * cb[g[j]];
        for (int i = 1; i <= tf; i++) a[f[i] - l] = 0;
        for (int i = 1; i <= tg; i++) b[g[i] - l] = 0;
    }
    Calc(l,mid); Calc(mid + 1,r);
}

inline void Solve()
{
    n = getint(); m = getint(); q = getint();
    for (int i = 1; i <= n; i++) A[i] = getint(),++ca[A[i]];
    for (int i = 1; i <= m; i++) B[i] = getint(),++cb[B[i]];
    Work(); Calc(0,65535); while (q--) Print(Ans[getint()]);
}

inline void Clear()
{
    memset(ca,0,sizeof(ca));
    memset(cb,0,sizeof(cb));
    memset(Ans,0,sizeof(Ans));
}

int main()
{
    #ifdef DMC
        freopen("binop1.in","r",stdin);
        freopen("test.out","w",stdout);
    #endif

    w[0] = 1; w[1] = ksm(5,(mo - 1) / maxn);
    for (int i = 2; i <= maxn; i++) w[i] = Mul(w[i - 1],w[1]);
    for (int i = 0; i <= maxn; i++) _w[i] = w[maxn - i];
    for (int i = 1; i <= maxn; i <<= 1) Inv[i] = ksm(i,mo - 2);
    for (int i = 2; i <= maxn; i <<= 1) LOG[i] = LOG[i >> 1] + 1;
    T = getint(); while (T--) Solve(),Clear();
    return 0;
}

你可能感兴趣的:(NTT,NTT)