给一数组 a a a,记录下每个连续子数组的最大值
有 Q Q Q次询问,每个询问将问:记录的数中有多少个数是小于/大于/等于给出的 K K K的
既然要求子数组最大值
那么对于一个位置 i ( 1 ≤ i ≤ n ) i(1≤i≤n) i(1≤i≤n),可以先暴力往左右两边找,找到最远处使得这个区间内的最大值是 a [ i ] a[i] a[i]
可以用线段树记录最大值,然后区间查找
但是这个时间复杂度是 O ( n 2 l o g ( n ) ) O(n^2log(n)) O(n2log(n))的
考虑优化
对于最远处,可以用二分来找
优化成 O ( n l o g ( n 2 ) ) O(nlog(n^2)) O(nlog(n2))
那么 i i i对答案的贡献就是( i i i-最左端的下标+1)*(最右端的下标- i i i+1),记录下来 ( n u m [ i ] ) (num[i]) (num[i])
将 a [ i ] a[i] a[i]排序(同时更改 n u m [ i ] num[i] num[i])
记录 n u m num num的 前缀和 和 后缀和
对于每个询问
分类讨论
如果是小于
二分找到最大的下标 ( x ) (x) (x)使得 a [ x ] < K a[x]
如果是大于
二分找到最小的下标 ( x ) (x) (x)使得 a [ x ] > K a[x]>K a[x]>K,输出 x x x的后缀和
如果是等于
先二分找到最大的 x x x使得 a [ x ] = K a[x]=K a[x]=K,再二分找到最大的 y y y使得 a n s [ y ] < K ans[y]
#include
#include
#include
using namespace std;
int n,m,i,l,r,mid,bj;
long long s1,s2,x,ans,ans1,ans2,pre[100005],suf[100005],tree[400005];
char ch;
struct node
{
long long val,sum;
}a[100005];
bool cmp(node x,node y)
{
return x.val<y.val;
}
void biuld(int now,int l,int r)
{
if (l==r)
{
tree[now]=a[l].val;
return;
}
int mid=(l+r)>>1;
biuld(now<<1,l,mid);
biuld(now<<1|1,mid+1,r);
tree[now]=max(tree[now<<1],tree[now<<1|1]);
}
long long query(int now,int l,int r,int p,int q)
{
if (l>=p&&r<=q) return tree[now];
int mid=(l+r)>>1;
long long res;
res=0;
if (p<=mid) res=max(res,query(now<<1,l,mid,p,q));
if (q>mid) res=max(res,query(now<<1|1,mid+1,r,p,q));
return res;
}
int main()
{
freopen("jxthree.in","r",stdin);
freopen("jxthree.out","w",stdout);
scanf("%d%d",&n,&m);
for (i=1;i<=n;i++)
scanf("%lld",&a[i].val);
biuld(1,1,n);
for (i=1;i<=n;i++)
{
l=1;
r=i-1;
s1=i;
s2=i;
while (l<=r)
{
mid=(l+r)>>1;
if (query(1,1,n,mid,i-1)<a[i].val)
{
s1=mid;
r=mid-1;
}
else l=mid+1;
}
l=i+1;
r=n;
while (l<=r)
{
mid=(l+r)>>1;
if (query(1,1,n,i+1,mid)<=a[i].val)
{
s2=mid;
l=mid+1;
}
else r=mid-1;
}
a[i].sum=(long long)(i-s1+1)*(long long)(s2-i+1);
}
sort(a+1,a+n+1,cmp);
for (i=1;i<=n;i++)
pre[i]=pre[i-1]+a[i].sum;
for (i=n;i>=1;i--)
suf[i]=suf[i+1]+a[i].sum;
for (i=1;i<=m;i++)
{
ch=getchar();
while (ch!='>'&&ch!='='&&ch!='<') ch=getchar();
if (ch=='<') bj=1;
if (ch=='=') bj=2;
if (ch=='>') bj=3;
scanf("%lld",&x);
ans=0;
ans1=0;
ans2=0;
if (bj==1)
{
l=1;
r=n;
while (l<=r)
{
mid=(l+r)>>1;
if (a[mid].val<x)
{
ans=mid;
l=mid+1;
}
else r=mid-1;
}
printf("%lld\n",pre[ans]);
}
if (bj==2)
{
l=1;
r=n;
while (l<=r)
{
mid=(l+r)>>1;
if (a[mid].val<=x)
{
ans1=mid;
l=mid+1;
}
else r=mid-1;
}
l=1;
r=n;
while (l<=r)
{
mid=(l+r)>>1;
if (a[mid].val<x)
{
ans2=mid;
l=mid+1;
}
else r=mid-1;
}
if (a[ans1].val==x) printf("%lld\n",pre[ans1]-pre[ans2]);
else printf("0\n");
}
if (bj==3)
{
l=1;
r=n;
while (l<=r)
{
mid=(l+r)>>1;
if (a[mid].val>x)
{
ans=mid;
r=mid-1;
}
else l=mid+1;
}
printf("%lld\n",suf[ans]);
}
}
fclose(stdin);
fclose(stdout);
return 0;
}