Shintaro \text{Shintaro} Shintaro 有 n n n 条龙。 第 i i i 条龙的力量值是 x i x_i xi。现在 Shintaro \text{Shintaro} Shintaro 想与这些龙交朋友。
Shintaro \text{Shintaro} Shintaro 会使用以下两种魔法来平衡龙的力量值(使某些龙的力量值相等),以免与他交朋友的龙互相打架。
强化魔法:消耗 a a a 点 mp,使某条龙的力量值增加 1 1 1 点。
弱化魔法:消耗 b b b 点 mp,使某条龙的力量值降低 1 1 1 点。
在第 i i i 次, Shintaro \text{Shintaro} Shintaro 想与前 i i i 条龙交朋友 ( 1 ≤ i ≤ n ) (1≤i≤n) (1≤i≤n)。我们有很多种使用魔法的方案,使前 i i i 条龙力量值相等。请你找到消耗 mp 点数最小的方案,并输出 mp 点数。
n ≤ 1 0 5 n\le10^5 n≤105
1s/512MB
考虑随便取一个力量值,然后将它减少 1 1 1,在它之上的力量值所贡献的代价只会变大,而且变大的幅度会越来越大;在它之下的力量值所贡献的代价只会变小,而且变小的幅度会越来越小。发现两边的代价函数都是下凸函数,所以总的代价也是下凸函数,可以三分解决。
实现上,开一个权值线段树,维护当前区间有多少个元素与总和。对于一个力量值,求出它前面与后面的元素大小和总和,就可以知道代价了。
三分最好使用类似于二分的写法,不然可能 TLE。
然后提供一个卡常的方法:权值线段树一般都是开了一个结构体,里面有各种变量,把它们单独拿出来开数组,不用结构体,这样可以快很多。原因可能是结构体内部变量会对齐,导致空间变大,从而寻址时间变多,常数变大。
这样就可以卡过去了,时间复杂度 O ( n log 2 V ) O(n\log^2 V) O(nlog2V), V V V 是值域,这里是 1 0 9 10^9 109。
#include
using namespace std;
#define ll long long
constexpr int N=1e5+1,Inf=1e9;
constexpr ll INF=2e18;
int n,cnt=1,A,B,a[N],ls[N*30],rs[N*30],sz[N*30];
ll sum[N*30];
void insert(int &rt,int l,int r,int x)
{
if(!rt) rt=++cnt;
if(l==r){
sz[rt]++,sum[rt]+=x;
return;
}
int mid=l+r>>1;
if(x<=mid) insert(ls[rt],l,mid,x);
else insert(rs[rt],mid+1,r,x);
sz[rt]=sz[ls[rt]]+sz[rs[rt]];
sum[rt]=sum[ls[rt]]+sum[rs[rt]];
}
pair<ll,int> querypre(int x)
{
if(x<0) return make_pair(0,0);
if(x>Inf) return make_pair(sum[1],sz[1]);
int l=0,r=Inf,rt=1,Sz=0;
ll Sum=0;
while(l!=r){
int mid=l+r>>1;
if(sz[rs[rt]]&&x>mid) Sz+=sz[ls[rt]],Sum+=sum[ls[rt]],rt=rs[rt],l=mid+1;
else rt=ls[rt],r=mid;
}
return make_pair(Sum+sum[rt],Sz+sz[rt]);
}
pair<ll,int> querynxt(int x)
{
if(x<0) return make_pair(sum[1],sz[1]);
if(x>Inf) return make_pair(0,0);
int l=0,r=Inf,rt=1,Sz=0;
ll Sum=0;
while(l!=r){
int mid=l+r>>1;
if(sz[ls[rt]]&&x<=mid) rt=ls[rt],r=mid;
else Sz+=sz[ls[rt]],Sum+=sum[ls[rt]],rt=rs[rt],l=mid+1;
}
return make_pair(Sum+sum[rt],Sz+sz[rt]);
}
ll F(ll x)
{
if(x<0||x>Inf) return INF;
auto pre=querypre(x-1),nxt=querypre(x);
nxt.first=sum[1]-nxt.first;
nxt.second=sz[1]-nxt.second;
return (pre.second*x-pre.first)*A+(nxt.first-nxt.second*x)*B;
}
ll solve(int n)
{
int rt=1,l=0,r=Inf,czn=l;
insert(rt,0,Inf,a[n]);
while(l<=r){
int mid=l+r>>1;
ll fmid=F(mid),fmid1=F(mid+1);
if(fmid<fmid1) czn=mid,r=mid-1;
else czn=mid+1,l=mid+1;
}
return F(czn);
}
int main()
{
freopen("c.in","r",stdin);
freopen("c.out","w",stdout);
cin.tie(0)->sync_with_stdio(0);
cin>>n>>A>>B;
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=1;i<=n;i++) cout<<solve(i)<<"\n";
}