Shintaro \text{Shintaro} Shintaro有 n n n条龙,第 i i i条龙的力量值为 x i x_i xi。现在 Shintaro \text{Shintaro} Shintaro想与这些龙交朋友。
Shintaro \text{Shintaro} Shintaro会使用以下两种魔法来平衡龙的力量值(使某些龙的力量值相等),以免与他交朋友的龙互相打架。
强化魔法:消耗 a a a点 m p mp mp,使某条龙的力量值增加 1 1 1点。
弱化魔法:消耗 b b b点 m p mp mp,使某条龙的力量值降低 1 1 1点。
在第 i i i次, Shintaro \text{Shintaro} Shintaro想与前 i i i条龙交朋友 ( 1 ≤ i ≤ n ) (1\leq i\leq n) (1≤i≤n)。我们有很多种使用魔法的方案,使前 i i i条龙力量值相等。请你找到消耗 m p mp mp点数最小的方案,并输出 m p mp mp点数。
1 ≤ n ≤ 1 0 5 , 1 ≤ a , b ≤ 1 0 4 , 1 ≤ x i ≤ 1 0 9 1\leq n\leq 10^5,1\leq a,b\leq 10^4,1\leq x_i\leq 10^9 1≤n≤105,1≤a,b≤104,1≤xi≤109
我们考虑将每条龙的力量值取到何值时代价最小。
对于每一个 i ( 1 ≤ i ≤ n ) i(1\leq i\leq n) i(1≤i≤n),将前 i i i条龙龙放在数轴上,取一个点 x x x,设在 x x x之前有 k k k条龙,则在 x x x之后有 i − k i-k i−k条龙。那么,我们将 x x x往右移一个单位,如果移动后 x x x的左边和右边的龙的数量不变,则代价会增加 k a − ( i − k ) b ka-(i-k)b ka−(i−k)b。当 k a − ( i − k ) b < 0 ka-(i-k)b<0 ka−(i−k)b<0时,显然将 x x x往右移一个单位是最优的,我们可以一直右移 x x x,直到 x x x在第 k + 1 k+1 k+1条龙对应的点上,然后继续判断新的 k k k是否满足 k a − ( i − k ) b < 0 ka-(i-k)b<0 ka−(i−k)b<0,再继续右移,直到不满足 k a − ( i − k ) b < 0 ka-(i-k)b<0 ka−(i−k)b<0,此时如果再右移肯定不优。
也就是说,我们要找到第一个 k k k使得 k a − ( i − k ) b ≥ 0 ka-(i-k)b\geq 0 ka−(i−k)b≥0,即 k ≥ i b a + b k\geq\dfrac{ib}{a+b} k≥a+bib。因为 k k k为整数,所以 k = ⌈ i b a + b ⌉ k=\lceil\dfrac{ib}{a+b}\rceil k=⌈a+bib⌉,那么取第 k k k条龙的力量值为所有龙最终的力量值即可使代价最小(这里的第 k k k条龙指数轴上从小到大的第 k k k条龙)。
我们把龙的力量值离散化一下,然后用权值线段树维护前 i i i条龙的力量值,还要维护线段树上的每个节点对应的区间中有多少条龙。每次用上面的方法求出 k k k,在线段树中求前 i i i条龙中第 k k k小的龙的力量值,再求出其他龙的力量值变为第 k k k小的龙的力量值的代价之和即可。
时间复杂度为 O ( n log n ) O(n\log n) O(nlogn)。
#include
#define lc k<<1
#define rc k<<1|1
using namespace std;
const int N=100000;
int n,k,v[N+5],num[N+5],hv[4*N+5];
long long a,b,sum,all,ans,tr[4*N+5];
void ch(int k,int l,int r,int x){
if(l==r&&l==x){
++hv[k];
tr[k]+=num[x];
return;
}
int mid=l+r>>1;
if(x<=mid) ch(lc,l,mid,x);
else ch(rc,mid+1,r,x);
hv[k]=hv[lc]+hv[rc];
tr[k]=tr[lc]+tr[rc];
}
int gtnum(int k,int l,int r,int x){
if(l==r) return l;
int mid=l+r>>1;
if(x<=hv[lc]) return gtnum(lc,l,mid,x);
else return gtnum(rc,mid+1,r,x-hv[lc]);
}
void find(int k,int l,int r,int x,int y){
if(x>y) return;
if(l>=x&&r<=y){
all+=hv[k];
sum+=tr[k];
return;
}
int mid=l+r>>1;
if(x<=mid) find(lc,l,mid,x,y);
if(y>mid) find(rc,mid+1,r,x,y);
}
int main()
{
// freopen("c.in","r",stdin);
// freopen("c.out","w",stdout);
scanf("%d%lld%lld",&n,&a,&b);
for(int i=1;i<=n;i++){
scanf("%d",&v[i]);num[i]=v[i];
}
sort(num+1,num+n+1);
int gs=unique(num+1,num+n+1)-num-1;
for(int i=1;i<=n;i++){
v[i]=lower_bound(num+1,num+gs+1,v[i])-num;
}
for(int i=1;i<=n;i++){
ch(1,1,n,v[i]);
k=(i*b+a+b-1)/(a+b);
int tmp=gtnum(1,1,n,k);
sum=all=0;find(1,1,n,1,tmp-1);
ans=(1ll*all*num[tmp]-sum)*a;
sum=all=0;find(1,1,n,tmp+1,n);
ans+=(sum-1ll*all*num[tmp])*b;
printf("%lld\n",ans);
}
return 0;
}