题目链接:https://nanti.jisuanke.com/t/38228
问题描述:求所有区间中区间和乘以区间最小值的最大值问题。
数据范围:1
大体思想:对于每一个元素,将它作为一个区间的最小值,然后找到满足这样条件的左右区间,使得该元素乘上这个区间和最大,最后遍历所有元素,取其最大值即可。
首先考虑当遍历到元素 a[i] >0的情况,从 a[i] 向左遍历直到出现元素小于 a[i] 或到数组的最左端停止,记为左区间l;再从a[i]向右遍历直到出现元素小于a[i]或数组的最右端,记为右区间r。因为a[i]>0,所以区间[l, r]中元素也都大于等于零,且此区间和就是a[i]对应的最大区间和。此时单调栈就能帮助我们找到对于每一个元素a[i](a[i]>0)对应的左右区间l, r。正向遍历找出l,反向遍历找出r.
当元素a[i] < 0时,我们依旧能先用单调栈找出对应的左右区间l,r,但是此时的区间[l,r]的和不一定就是满足a[i]对应下的最小区间
(因为a[i]<0,所以找对应的最小区间和),但是我们能发现最小的区间和肯定在区间[l,r]中,这是我们使用线段树维护一个包含a[i]节点的区间最大连续字段和(在建树时将用a[i]的相反数,就可以得到最小区间和)
当元素a[i]=0时,直接就是0。
代码如下:
#include
using namespace std;
#define ll long long
#define inf 0x3f3f3f3f
#define mes(a, val) memset(a, val, sizeof a)
#define mec(b, a) memcpy(b, a, sizeof a)
const int maxn = 5e5+100;
int n;
ll pre[maxn];
struct node{
int val, s, e;
};
vectorv(maxn);
struct SegmentTree{
int l, r;
///lmax表示从左区间向右的最大连续字段和,rmax表示从右区间向左的最大连续字段和
///sum表示区间和,dat表示区间最大连续字段和
ll lmax, rmax, sum, dat;
}tree[4*maxn];
struct Data{
ll lmax, rmax, sum, dat;
};
void push(int p){
tree[p].lmax = max(tree[p<<1].lmax, tree[p<<1].sum+tree[p<<1|1].lmax);
tree[p].rmax = max(tree[p<<1|1].rmax, tree[p<<1|1].sum+tree[p<<1].rmax);
tree[p].sum = tree[p<<1].sum + tree[p<<1|1].sum;
tree[p].dat = max(max(tree[p<<1].dat, tree[p<<1|1].dat), tree[p<<1].rmax+tree[p<<1|1].lmax);
}
void build(int l, int r, int p){
tree[p].l = l; tree[p].r = r;
if(l == r){
tree[p].lmax = tree[p].rmax = tree[p].sum = tree[p].dat = -v[l].val;
return;
}
int mid = (l + r) >> 1;
build(l, mid, p<<1);
build(mid+1, r, p<<1|1);
push(p);
}
///注意返回的是一个Data类型
Data query(int l, int r, int p){
if(l <= tree[p].l && r >= tree[p].r)return Data{tree[p].lmax, tree[p].rmax, tree[p].sum, tree[p].dat };
int mid = (tree[p].l + tree[p].r)>>1;
if(mid >= r) return query(l, r, p<<1);
if(mid < l)return query(l, r, p<<1|1);
Data a = query(l, r, p<<1), b = query(l, r, p<<1|1), c;
c.sum = a.sum+b.sum;
c.lmax = max(a.lmax, a.sum+b.lmax);
c.rmax = max(b.rmax, b.sum+a.rmax);
c.dat = max(max(a.dat, b.dat), a.rmax+b.lmax);
return c;
}
int main()
{
mes(pre, 0);
scanf("%d", &n);
pre[0] = 0;
for(int i = 1; i <= n; i ++){
scanf("%d", &v[i].val);
pre[i] = pre[i-1]+v[i].val;
v[i].s = i; v[i].e = i;
}
stacks;
int i = 0;
while(i <= n){
if(s.empty() || v[i].val > v[s.top()].val){
s.push(i);
i ++;
}
else {
v[i].s = v[s.top()].s;
s.pop();
}
}
while(!s.empty()){
s.pop();
}
i = n;
while(i >= 1){
if(s.empty() || v[i].val > v[s.top()].val){
s.push(i);
i --;
}
else {
v[i].e = v[s.top()].e;
s.pop();
}
}
ll ans = -inf;
build(1, n, 1);
for(int i = 1; i <= n; i ++){
ll val = (ll)v[i].val;
int l = v[i].s, r = v[i].e;
if(val > 0){
ans = max(ans, val*(pre[r]-pre[l-1]));
}
else if(val == 0) {
ans = max(ans, (ll)0);
}
else if(val < 0){
Data la, lb;
la = query(l, i, 1);
ll cnt1 = 0, cnt2 = 0;
if(r == i){
cnt1 = min(-1*la.rmax, (ll)0);
cnt2 = 0;
}
else {
lb = query(i+1, r, 1);
cnt1 = min(-1*la.rmax, (ll)0);
cnt2 = min(-1*lb.lmax, (ll)0);
}
ll cnt = cnt1+cnt2;
ans = max(ans, val*cnt);
}
// ans = max(ans, v[i].val*(pre[v[i].e]-pre[v[i].s-1]));
}
cout<