Kattis aplusb A+B Problem FFT

A+B Problem

Given NN integers in the range [50000,50000][−50000,50000], how many ways are there to pick three integers aiaiajajakak, such that iijjkk are pairwise distinct and ai+aj=akai+aj=ak? Two ways are different if their ordered triples (i,j,k)(i,j,k)of indices are different.

Input

The first line of input consists of a single integer NN (1N2000001≤N≤200000). The next line consists of NN space-separated integers a1,a2,,aNa1,a2,…,aN.

Output

Output an integer representing the number of ways.

Sample Input 1 Sample Output 1
4
1 2 3 4
4
Sample Input 2 Sample Output 2
6
1 1 3 3 4 6
10
题意:给一些数,求满足 i,j,k两两不相等且ai+aj=ak条件的(i,j,k)个数。
思路:先求出ai+aj的所有可能的值的个数,这个记录下每种数的个数直接和自己FFT就可求出
然后把所有在数组中的数对应的个数相加就是答案。
注意:(1)输入的数可能是负数,要先全部加上一个数把他们全变成正数。
(2)FFT直接求出的是包括i==j的个数,求出后把i==j的情况减掉。
(3)数可能为0,为0会导致FFT算出的结果有(i,j,k)出现相同的情况.
比如:0+0=0,0+ai=ai,ai+0=ai。
很容易发现如果相加结果不为0只是会多 0+ai=ai,ai+0=ai这两种情况,直接减掉两倍的0的个数就好。
如果相加结果为0只是会多0+0=0的情况,这种情况和上面那种情况很像,就是把ai变成0,所以就是减掉两倍的(0的个数-1)就好,那个1是减掉那一个0变成ai的0。

#include 
#include 
#include 
#include 
using namespace std;
#define ll long long
//FFT模板开始  
const double PI = acos(-1.0);  
struct Virt  
{  
    double r, i;  
    Virt(double r = 0.0, double i = 0.0)  
    {  
        this->r = r;  
        this->i = i;  
    }  
    Virt operator + (const Virt &x)  
    {  
        return Virt(r + x.r, i + x.i);  
    }  
    Virt operator - (const Virt &x)  
    {  
        return Virt(r - x.r, i - x.i);  
    }  
    Virt operator * (const Virt &x)  
    {  
        return Virt(r * x.r - i * x.i, i * x.r + r * x.i);  
    }  
};  
//雷德算法--倒位序  
void Rader(Virt F[], int len)  
{  
    int j = len >> 1;  
    for(int i = 1; i < len - 1; i++)  
    {  
        if(i < j) swap(F[i], F[j]);  
        int k = len >> 1;  
        while(j >= k)  
        {  
            j -= k;  
            k >>= 1;  
        }  
        if(j < k) j += k;  
    }  
}  
//FFT实现  
void FFT(Virt F[], int len, int on)  
{  
    Rader(F, len);  
for(int h = 2; h <= len; h <<= 1) 
//分治后计算长度为h的DFT  
    {  
        Virt wn( cos(-on * 2 * PI / h), sin(-on * 2 * PI / h)); 
//单位复根e^(2*PI/m)用欧拉公式展开  
        for(int j = 0; j < len; j += h)  
        {  
            Virt w(1, 0);           
//旋转因子  
            for(int k = j; k < j + h / 2; k++)  
            {  
                Virt u = F[k];  
                Virt t = w * F[k + h / 2];  
                F[k] = u + t;    
//蝴蝶合并操作  
                F[k + h / 2] = u - t;  
                w = w * wn;     
//更新旋转因子  
            }  
        }  
    }  
    if(on == -1)  
        for(int i = 0; i < len; i++)  
            F[i].r /= len;  
}  
//求卷积  
void Conv(Virt a[], Virt b[], int n)  
{  
    FFT(a, n, 1);  
    FFT(b, n, 1);  
    for(int i = 0; i < n; i++)  
        a[i] = a[i] * b[i];  
    FFT(a, n, -1);  
}  
//FFT模板结束 
const int T = 50000;
const int MAXN = 200000;
ll num[MAXN+5];
ll cnt[MAXN+5];
Virt a[2*MAXN+5];
Virt b[2*MAXN+5];
ll ans[2*MAXN+5];
int main()
{
	int n;
	scanf("%d",&n);
	int zero = 0;
	for(int i = 0; i < n; i++)
	{
		scanf("%lld",&num[i]);
		if(num[i] == 0) zero++;
		cnt[num[i] + T]++;
	}
	int len = 1;
	while(len < MAXN) len <<= 1;
	for(int i = 0; i < MAXN; i++)
	{
		a[i] = b[i] = Virt(1.0 * cnt[i], 0.0);
	}
	Conv(a, b, len);
	for(int i = 0; i < len; i++)
	{
		ans[i] = (ll)(a[i].r + 0.5);
	}
	for(int i = 0; i < n; i++)
	{
		ans[(num[i] + T) * 2]--;
	}
	ll res = 0;
	for(int i = 0; i < n; i++)
	{
		res += ans[num[i]+T*2];
		res -= (zero - (num[i] == 0)) * 2;
	}
	printf("%lld\n",res);
	return 0;
}


你可能感兴趣的:(FFT)