POJ 1990 MooFest 树状数组

一、题目大意

我们有N头牛,需要两两之间相互通讯,其中每头牛对应一个坐标x和一个听力v,设第i头牛的听力为v(i),坐标为x(i)(1<=x<=20000),

已知牛i和牛j相互通讯需要的音量为 max(v(i),v(j))*|x(i)-x(j)|,求出N(N-1)对通讯的音量的总和。

二、解题思路

我们首先看下,牛i和牛j相互通讯需要的音量为 max(v(i),v(j))*|x(i)-x(j)|,那么如果我们对所有的牛根据v来排序,保证对于 i < j ,v[i] <= v[j],之后从第一个开始循环,计算 i 和 i 前面的所有牛的坐标差之和  * v[i]即可。 

然后我们可以记录2个树状数组,其中一个bit用来存坐标,另一个用bitCnt来计数,每一次循环执行如下的事情

1)计算bit的前 x[i]项和leftSum,和 前 262144 的和allSum(所有元素)

2)计算bitCnt前 x[i]项和leftCnt,和 前 262144 的和allCnt(所有个数)

3)其中leftCnt就是牛i左边的牛的坐标和,例如3头牛 1 2 5,那么它左边的坐标和为1+2=3,然后leftCnt就是牛i左边牛的数量,即2,那么 i 和左边的坐标差=(5-1)+(5-1)=5 * 2 - (2 + 1)

所以左边的牛的音量 = (x[i] * leftCnt -  leftSum )*v[i]

然后allCnt-leftCnt就是右边的数量rightCnt,allSum-leftSum就是右边的数量rightSum

同理右边的牛的音量 = ( rightSum - x[i]*rightCnt )*v[i]

之后把左右两边音量的和加到ans里即可

4)将bit的x[i]位+x[i],将bitCnt的x[i]位+1

bit[ x[ i ] ]+=x[i]

bitCnt[ x[ i ] ]+=1

同步更新两棵树的父节点...

(备注:本题目中坐标乘以数量然后不断求和的过程中 n * (n-1) * 20000可能会大于int32,注意给结果开long long)

三、代码

#include 
#include 
using namespace std;
typedef long long ll;
typedef pair P;
P num[262150];
int bit[262150], n_, n, bitCnt[262150];
ll ans = 0LL;
void input()
{
    scanf("%d", &n_);
    for (int i = 1; i <= n_; i++)
    {
        scanf("%d%d", &num[i].first, &num[i].second);
    }
    sort(num + 1, num + (1 + n_));
}
void init()
{
    n = 262144;
    for (int i = 0; i <= n; i++)
    {
        bit[i] = 0;
        bitCnt[i] = 0;
    }
}
void update(int r, int v)
{
    if (r <= 0)
    {
        return;
    }
    for (int i = r; i <= n; i = i + (i & (-i)))
    {
        bit[i] = bit[i] + v;
    }
}
void updateCnt(int r, int v)
{
    if (r <= 0)
    {
        return;
    }
    for (int i = r; i <= n; i = i + (i & (-i)))
    {
        bitCnt[i] = bitCnt[i] + v;
    }
}
int query(int r)
{
    int sum = 0;
    for (int i = r; i > 0; i = i - (i & (-i)))
    {
        sum = sum + bit[i];
    }
    return sum;
}
int queryCnt(int r)
{
    int sum = 0;
    for (int i = r; i > 0; i = i - (i & (-i)))
    {
        sum = sum + bitCnt[i];
    }
    return sum;
}
void solve()
{
    for (int i = 1; i <= n_; i++)
    {
        int leftSum = query(num[i].second);
        int allSum = query(n);
        int leftCnt = queryCnt(num[i].second);
        int allCnt = queryCnt(n);
        ans = ans + (((ll)((leftCnt * num[i].second) - leftSum)) * ((ll)num[i].first));
        ans = ans + (((ll)((allSum - leftSum) - ((allCnt - leftCnt) * num[i].second))) * ((ll)num[i].first));
        update(num[i].second, num[i].second);
        updateCnt(num[i].second, 1);
    }
}
int main()
{
    input();
    init();
    solve();
    printf("%lld\n", ans);
    return 0;
}

你可能感兴趣的:(算法,数据结构)