[POJ 1990] MooFest (树状数组)

POJ - 1990
N 头牛排成一行,每两头牛之间进行交谈
代价为 max(v[i],v[j])distance(i,j) ,问代价和为多少

很显然,如果暴力去做的话,是 O(n2) 的,必然会 T
所以我们要依照 v 的值来依次计算
即对 v 的值从小到大进行排序,然后对 v 从小到大计算
计算到排序后第 i 头牛时,前面的牛的 v 都小于等于它
但是我们无法确定哪些牛原来的位置在它之前,哪些牛在它之后,所以距离又变得难以计算
这时候就要动用树状数组了

我们需要两个量,一个之前的牛中,有多少头牛的位置在当前牛的左边
第二个是这些牛对原点的距离和是多少
我们边计算边更新 BIT,依照离原点的距离组织起 BIT
当处理完第 i 头牛后,将 BIT[1] 里第 i 头牛右边的所有牛都 +1 ,即在 i +1
将 BIT[0] 里第 i 头牛右边的所有牛都 +xi ,即在 i +xi

这样,我们就能统计出已经处理过的牛中,有多少头原来在当前牛的左边,有多少在右边
在左边的距离和是多少,在右边的距离和是多少
在左边的距离计算是当前的乘以左边的总数减去距离和,右边的是当前的距离乘以右边的总数减去距离和
cntl(xi)suml+sumrcntr(xi)
然后更新答案就好了

树状数组主要是动态维护一个前缀和,但其实它还有个副产品,就是同时增减一段区间内所有的数
例如我要给 [l,r] 的数全部加上 v ,做法就是
add(l,v); add(r+1,-v);
这样将会给 [l,maxn] 的所有数加上 v ,给 [r+1,maxn] 的所有数减去 v ,从而达到效果
不过注意这个加是从当前点到区间末尾的,稍微有点别扭
同样,查询 [l,r] 的数的和就是
query(r)query(l1)

#include <cstdio>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <map>
#include <set>
#include <queue>
using namespace std;
typedef pair<int,int> Pii;
typedef long long LL;
typedef unsigned long long ULL;
typedef double DBL;
typedef long double LDBL;
#define MST(a,b) memset(a,b,sizeof(a))
#define CLR(a) MST(a,0)
#define Pow2(a) (a*a)

const int maxn=2e4+10;
int N,X;
struct data{int id,x,v;} inpt[maxn];
int BIT[2][maxn];

bool Dcmp(data u, data v){return u.v<v.v;}
int lowbit(int x){return x&-x;};
void add(int,int,int);
int query(int,int);

int main()
{
    while(~scanf("%d", &N))
    {
        CLR(BIT);
        X=-1;
        for(int i=1; i<=N; i++)
        {
            inpt[i].id=i;
            scanf("%d%d", &inpt[i].v, &inpt[i].x);
            X=max(X, inpt[i].x);
        }
        sort(inpt+1,inpt+N+1,Dcmp);
        int sum=0;
        LL ans=0;
        for(int i=1; i<=N; i++)
        {
            int cntl,suml,cnt=i-1,cntr,sumr;
            cntl=query(1,inpt[i].x);
            suml=query(0,inpt[i].x);
            cntr=cnt-cntl;sumr=sum-suml;
            ans+=(LL)inpt[i].v*(cntl*inpt[i].x-suml+sumr-cntr*inpt[i].x);
            sum+=inpt[i].x;
            add(1,inpt[i].x,1);
            add(0,inpt[i].x,inpt[i].x);
        }
        cout << ans << '\n';
    }
    return 0;
}

void add(int flr,int np, int v)
{
    while(np<=X)
    {
        BIT[flr][np]+=v;
        np+=lowbit(np);
    }
}

int query(int flr,int np)
{
    int res=0;
    while(np)
    {
        res+=BIT[flr][np];
        np-=lowbit(np);
    }
    return res;
}

你可能感兴趣的:(poj)