有一个长度为 n n n的序列 a a a。
定义一个合法的二元组 ( i , j ) (i,j) (i,j)满足 i , j i,j i,j为整数且 i ≤ j i\leq j i≤j,合法二元组的分数为 a i − a j a_i-a_j ai−aj。定义一个合法二元组 ( i , j ) (i,j) (i,j)在区间 [ l , r ] [l,r] [l,r]内,当且仅当 l ≤ i , j ≤ r l\leq i,j\leq r l≤i,j≤r。
有 m m m次操作,每次操作有两个类型:
1 l r x
:表示将 a a a中第 l l l个位置到第 r r r个位置加 x x x2 l r k
:表示询问选出 k k k个在区间 [ l , r ] [l,r] [l,r]内不同的二元组 ( i , j ) (i,j) (i,j)的最大分数和是多少1 ≤ n , m ≤ 1 0 5 , − 1 0 6 ≤ a i , x ≤ 1 0 6 , ∑ k ≤ 3 × 1 0 5 1\leq n,m\leq 10^5,-10^6\leq a_i,x\leq 10^6,\sum k\leq 3\times 10^5 1≤n,m≤105,−106≤ai,x≤106,∑k≤3×105
数据保证 [ l , r ] [l,r] [l,r]中至少有 k k k个合法二元组。
我们先考虑 k = 1 k=1 k=1的情况。也就是说,要在区间 [ l , r ] [l,r] [l,r]中找到一个合法二元组 ( i , j ) (i,j) (i,j)使得 a i − a j a_i-a_j ai−aj最大。
可以用线段树解决。线段树维护三个值:区间最大值 m x mx mx,区间最小值 m n mn mn,区间最大答案 a n s ans ans。
那么,对于一个点 p p p, a n s p = max ( a n s l s , a n s r s , m x l s − m n r s ) ans_p=\max(ans_{ls},ans_{rs},mx_{ls}-mn_{rs}) ansp=max(ansls,ansrs,mxls−mnrs)
当整体加上某个值的时候,若一个区间内都需要加,则最大值 m x mx mx和最小值 m n mn mn增加,但因为答案是两个 a a a的差,所以最大答案 a n s ans ans不变。
这样,我们就解决了 k = 1 k=1 k=1的情况。。
我们可以用一个结构体来表示 i ∈ [ l 1 , r 1 ] i\in[l_1,r_1] i∈[l1,r1]且 j ∈ [ l 2 , r 2 ] j\in[l_2,r_2] j∈[l2,r2]并存储其最优的一组合法二元组 ( i , j ) (i,j) (i,j)以及其分数,我们可以把这些 ( i , j ) (i,j) (i,j)看成一个矩形。每次处理完一个矩形,就将这个矩形分成不包括最优合法二元组多个矩形继续来求,此时能保证原矩形分成的多个矩形中每个矩形的最优合法二元组的分数都不超过原矩形的分数。可以用优先队列,按最优合法二元组的分数从大到小排序,选了 k k k个最优合法二元组并计算贡献之后就不需要再选了。
一般的矩形是不好求解答案的,不过有两种矩形可以:
显然题意要询问的就是第一种矩形。
那能不能在每次分裂的时候都将矩形分裂为这两种矩形呢?是可以的。
对于第一种矩形, i , j ∈ [ l , r ] i,j\in[l,r] i,j∈[l,r],假设最优解位于 ( x , y ) (x,y) (x,y),我们可以将这个矩形分为如下几种矩形:
对于第二种矩形, i ∈ [ l 1 , r 1 ] , j ∈ [ l 2 , r 2 ] i\in[l_1,r_1],j\in[l_2,r_2] i∈[l1,r1],j∈[l2,r2],假设最优解位于 ( x , y ) (x,y) (x,y),我们可以将这个矩形分为如下几种矩形:
这样,问题就解决了。
时间复杂度为 O ( n log n + ( ∑ k ) ( log n + log ( ∑ k ) ) ) O(n\log n+(\sum k)(\log n+\log(\sum k))) O(nlogn+(∑k)(logn+log(∑k)))。
#include
#define lc k<<1
#define rc k<<1|1
using namespace std;
const int N=100000;
int n,m,mxl,wl,wr,mxw[4*N+5],mnw[4*N+5],cl[4*N+5],cr[4*N+5];
long long ans,mxd,zl,zr,a[N+5],mx[4*N+5],mn[4*N+5],tr[4*N+5],ly[4*N+5];
struct node{
int l1,r1,l2,r2,z,x,y;
long long vl;
bool operator<(const node ax)const{
return vl<ax.vl;
}
};
priority_queue<node>q;
void up(int k){
if(mx[lc]>mx[rc]) mx[k]=mx[lc],mxw[k]=mxw[lc];
else mx[k]=mx[rc],mxw[k]=mxw[rc];
if(mn[lc]<mn[rc]) mn[k]=mn[lc],mnw[k]=mnw[lc];
else mn[k]=mn[rc],mnw[k]=mnw[rc];
tr[k]=max(max(tr[lc],tr[rc]),mx[lc]-mn[rc]);
if(tr[k]==tr[lc]) cl[k]=cl[lc],cr[k]=cr[lc];
else if(tr[k]==tr[rc]) cl[k]=cl[rc],cr[k]=cr[rc];
else cl[k]=mxw[lc],cr[k]=mnw[rc];
}
void build(int k,int l,int r){
if(l==r){
mx[k]=mn[k]=a[l];
mxw[k]=mnw[k]=cl[k]=cr[k]=l;
tr[k]=0;
return;
}
int mid=l+r>>1;
build(lc,l,mid);
build(rc,mid+1,r);
up(k);
}
void down(int k){
mx[lc]+=ly[k];
mn[lc]+=ly[k];
ly[lc]+=ly[k];
mx[rc]+=ly[k];
mn[rc]+=ly[k];
ly[rc]+=ly[k];
ly[k]=0;
}
void ch(int k,int l,int r,int x,int y,int v){
if(l>=x&&r<=y){
mx[k]+=v;
mn[k]+=v;
ly[k]+=v;
return;
}
if(ly[k]) down(k);
int mid=l+r>>1;
if(x<=mid) ch(lc,l,mid,x,y,v);
if(y>mid) ch(rc,mid+1,r,x,y,v);
up(k);
}
void find(int k,int l,int r,int x,int y){
if(l>=x&&r<=y){
ans=max(max(ans,tr[k]),mxd-mn[k]);
if(ans==tr[k]) wl=cl[k],wr=cr[k];
else if(ans==mxd-mn[k]) wl=mxl,wr=mnw[k];
if(mx[k]>mxd) mxd=mx[k],mxl=mxw[k];
return;
}
if(ly[k]) down(k);
int mid=l+r>>1;
if(x<=mid) find(lc,l,mid,x,y);
if(y>mid) find(rc,mid+1,r,x,y);
}
void gtmx(int k,int l,int r,int x,int y){
if(l>=x&&r<=y){
if(zl<mx[k]) zl=mx[k],wl=mxw[k];
return;
}
if(ly[k]) down(k);
int mid=l+r>>1;
if(x<=mid) gtmx(lc,l,mid,x,y);
if(y>mid) gtmx(rc,mid+1,r,x,y);
}
void gtmn(int k,int l,int r,int x,int y){
if(l>=x&&r<=y){
if(zr>mn[k]) zr=mn[k],wr=mnw[k];
return;
}
if(ly[k]) down(k);
int mid=l+r>>1;
if(x<=mid) gtmn(lc,l,mid,x,y);
if(y>mid) gtmn(rc,mid+1,r,x,y);
}
void pls(int l1,int r1,int l2,int r2,int z){
if(z==1){
ans=mxd=-1e18;
find(1,1,n,l1,r1);
q.push((node){l1,r1,l2,r2,z,wl,wr,ans});
}
else{
zl=-1e18;zr=1e18;
gtmx(1,1,n,l1,r1);gtmn(1,1,n,l2,r2);
q.push((node){l1,r1,l2,r2,z,wl,wr,zl-zr});
}
}
void work(int vl,int vr,int x){
pls(vl,vr,vl,vr,1);
long long sum=0;
for(int o=1;o<=x;o++){
node t=q.top();q.pop();
sum+=t.vl;
int x=t.x,y=t.y;
if(t.z==1){
int l=t.l1,r=t.r1;
if(x>l) pls(l,x-1,l,x-1,1);
if(x>l) pls(l,x-1,x,r,2);
if(x!=y) pls(x,x,x,x,1);
if(x<y-1) pls(x,x,x+1,y-1,2);
if(y<r) pls(x,x,y+1,r,2);
if(x<r) pls(x+1,r,x+1,r,1);
}
else{
if(x>t.l1) pls(t.l1,x-1,t.l2,t.r2,2);
if(t.l2<y) pls(x,x,t.l2,y-1,2);
if(y<t.r2) pls(x,x,y+1,t.r2,2);
if(x<t.r1) pls(x+1,t.r1,t.l2,t.r2,2);
}
}
printf("%lld\n",sum);
while(!q.empty()) q.pop();
}
int main()
{
freopen("D.in","r",stdin);
freopen("D.out","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
}
build(1,1,n);
for(int o=1,tp,l,r,x;o<=m;o++){
scanf("%d%d%d%d",&tp,&l,&r,&x);
if(tp==1) ch(1,1,n,l,r,x);
else work(l,r,x);
}
return 0;
}