[agc023E]Inversions

题目大意

给你一个大小为n的数组a[1..n]。
一个排列P[1..n]是合法的,当且仅当对于所有i=1~n,P[i]<=a[i]。
问你所有合法排列的逆序对个数。
n<=2e5。

分析

我们先考虑总合法排列数怎么算。
设cnt[i]表示a的值大于等于i的数量。
我们考虑从大到小填数,那么一个位置能够填了,之后也一定能够填,就没有后效性了。
数量就是 i=1..ncnt[i](ni) ∏ i = 1.. n c n t [ i ] − ( n − i )
考虑两个位置 i,j,i<j i , j , 设 i < j ,如果a[i]<=a[j],考虑p[i]和p[j]的所有不同情况,那么显然p[j]>a[i]的那些情况不是逆序对,所以我们可以先直接把a[j]变成a[i],然后再考虑剩下的情况。那么一定有一半的情况是 p[i]>p[j] p [ i ] > p [ j ] ,另一半 p[i]<p[j] p [ i ] < p [ j ] (因为a[i]=a[j]了)。对总逆序对的贡献就是a[j]强制等于a[i]时,合法排列数/2。
而当 a[i]>a[j] a [ i ] > a [ j ] 的时候,可以这样算:总合法排列数-(a[i]强制等于a[j]时的合法排列数/2)。
那么这就是一个 O(n2) O ( n 2 ) 的做法了。
考虑优化。
我们按权值从大到小做,拿下标为原数组下标的线段树维护一些乘积的和即可。
O(nlogn)

代码

#include 
#include
#include
#include
using namespace std;
#define fo(i,j,k) for(i=j;i<=k;i++)
#define fd(i,j,k) for(i=j;i>=k;i--)
#define cmax(a,b) (a=(a>b)?a:b)
#define cmin(a,b) (a=(a
typedef long long ll;
typedef double db;
const int N=2e5+5,mo=1e9+7;
struct rec
{
    int sum,cnt;
}tr[N*4],tmp;
int tag[N*4];
rec operator +(rec a,rec b)
{
    return {(a.sum+b.sum)%mo,a.cnt+b.cnt};
}
int n,a[N],cnt[N],prod,i,x,fac,d[N],ans,rev2;
int tt,b[N],nxt[N],fst[N];
void cr(int x,int y)
{
    tt++;
    b[tt]=y;
    nxt[tt]=fst[x];
    fst[x]=tt;
}
int ksm(int x,int y)
{
    int ret=1;
    while (y)
    {
        if (y&1) ret=1ll*ret*x%mo;
        y>>=1;
        x=1ll*x*x%mo;
    }
    return ret;
}
void dw(int x,int s)
{
    if (tag[x]==1) return ;
    if (!s) tag[x*2]=1ll*tag[x*2]*tag[x]%mo,tag[x*2+1]=1ll*tag[x*2+1]*tag[x]%mo;
    tr[x].sum=1ll*tr[x].sum*tag[x]%mo;
    tag[x]=1;
}
rec get(int x,int l,int r,int i,int j)
{
    if (l==i&&r==j) return tr[x];
    int m=l+r>>1;
    dw(x*2,l==m);
    dw(x*2+1,m+1==r);
    rec ret={0,0};
    if (i<=m) ret=ret+get(x*2,l,m,i,min(m,j));
    if (mx*2+1,m+1,r,max(i,m+1),j);
    return ret;
}
void change(int x,int l,int r,int p,int v)
{
    if (l==r) 
    {
        tr[x].cnt=1,tr[x].sum=v;
        return ;
    }
    int m=l+r>>1;
    dw(x*2,l==m);
    dw(x*2+1,m+1==r);
    if (p<=m) change(x*2,l,m,p,v);
    else change(x*2+1,m+1,r,p,v);
    tr[x]=tr[x*2]+tr[x*2+1];
}
int main()
{
    freopen("t10.in","r",stdin);
    freopen("t10.out","w",stdout);
    scanf("%d",&n);
    fo(i,1,n) 
    {
        scanf("%d",a+i);
        cr(a[i],i);
        cnt[a[i]]++;
    }
    fd(i,n,1) cnt[i]+=cnt[i+1];
    prod=1;
    fo(i,1,n) prod=1ll*prod*(cnt[i]-(n-i))%mo;
    fo(i,1,n*4) tag[i]=1;
    rev2=ksm(2,mo-2);
    fd(x,n,1)
    {
        d[0]=0;
        for(int p=fst[x];p;p=nxt[p]) d[++d[0]]=b[p];
        fac=ll(cnt[x]-1-(n-x))*ksm(cnt[x]-(n-x),mo-2)%mo;
        fo(i,1,d[0])
        {
            tmp=get(1,1,n,d[i],n);
            ans=(ans+1ll*tmp.sum*rev2)%mo;
            tmp=get(1,1,n,1,d[i]);
            ans=(ans+1ll*tmp.cnt*prod-1ll*tmp.sum*rev2)%mo;
        }
        ans=(ans+1ll*d[0]*(d[0]-1)/2%mo*prod%mo*rev2)%mo;
        fo(i,1,d[0]) change(1,1,n,d[i],prod);
        tag[1]=1ll*tag[1]*fac%mo;
        dw(1,1==n);
    }
    if (ans<0) ans+=mo;
    printf("%d",ans);
}

你可能感兴趣的:(计数类问题,线段树,树状数组)