给你一个大小为n的数组a[1..n]。
一个排列P[1..n]是合法的,当且仅当对于所有i=1~n,P[i]<=a[i]。
问你所有合法排列的逆序对个数。
n<=2e5。
我们先考虑总合法排列数怎么算。
设cnt[i]表示a的值大于等于i的数量。
我们考虑从大到小填数,那么一个位置能够填了,之后也一定能够填,就没有后效性了。
数量就是 ∏i=1..ncnt[i]−(n−i) ∏ i = 1.. n c n t [ i ] − ( n − i )
考虑两个位置 i,j,设i<j i , j , 设 i < j ,如果a[i]<=a[j],考虑p[i]和p[j]的所有不同情况,那么显然p[j]>a[i]的那些情况不是逆序对,所以我们可以先直接把a[j]变成a[i],然后再考虑剩下的情况。那么一定有一半的情况是 p[i]>p[j] p [ i ] > p [ j ] ,另一半 p[i]<p[j] p [ i ] < p [ j ] (因为a[i]=a[j]了)。对总逆序对的贡献就是a[j]强制等于a[i]时,合法排列数/2。
而当 a[i]>a[j] a [ i ] > a [ j ] 的时候,可以这样算:总合法排列数-(a[i]强制等于a[j]时的合法排列数/2)。
那么这就是一个 O(n2) O ( n 2 ) 的做法了。
考虑优化。
我们按权值从大到小做,拿下标为原数组下标的线段树维护一些乘积的和即可。
O(nlogn)
#include
#include
#include
#include
using namespace std;
#define fo(i,j,k) for(i=j;i<=k;i++)
#define fd(i,j,k) for(i=j;i>=k;i--)
#define cmax(a,b) (a=(a>b)?a:b)
#define cmin(a,b) (a=(a
typedef long long ll;
typedef double db;
const int N=2e5+5,mo=1e9+7;
struct rec
{
int sum,cnt;
}tr[N*4],tmp;
int tag[N*4];
rec operator +(rec a,rec b)
{
return {(a.sum+b.sum)%mo,a.cnt+b.cnt};
}
int n,a[N],cnt[N],prod,i,x,fac,d[N],ans,rev2;
int tt,b[N],nxt[N],fst[N];
void cr(int x,int y)
{
tt++;
b[tt]=y;
nxt[tt]=fst[x];
fst[x]=tt;
}
int ksm(int x,int y)
{
int ret=1;
while (y)
{
if (y&1) ret=1ll*ret*x%mo;
y>>=1;
x=1ll*x*x%mo;
}
return ret;
}
void dw(int x,int s)
{
if (tag[x]==1) return ;
if (!s) tag[x*2]=1ll*tag[x*2]*tag[x]%mo,tag[x*2+1]=1ll*tag[x*2+1]*tag[x]%mo;
tr[x].sum=1ll*tr[x].sum*tag[x]%mo;
tag[x]=1;
}
rec get(int x,int l,int r,int i,int j)
{
if (l==i&&r==j) return tr[x];
int m=l+r>>1;
dw(x*2,l==m);
dw(x*2+1,m+1==r);
rec ret={0,0};
if (i<=m) ret=ret+get(x*2,l,m,i,min(m,j));
if (mx *2+1,m+1,r,max(i,m+1),j);
return ret;
}
void change(int x,int l,int r,int p,int v)
{
if (l==r)
{
tr[x].cnt=1,tr[x].sum=v;
return ;
}
int m=l+r>>1;
dw(x*2,l==m);
dw(x*2+1,m+1==r);
if (p<=m) change(x*2,l,m,p,v);
else change(x*2+1,m+1,r,p,v);
tr[x]=tr[x*2]+tr[x*2+1];
}
int main()
{
freopen("t10.in","r",stdin);
freopen("t10.out","w",stdout);
scanf("%d",&n);
fo(i,1,n)
{
scanf("%d",a+i);
cr(a[i],i);
cnt[a[i]]++;
}
fd(i,n,1) cnt[i]+=cnt[i+1];
prod=1;
fo(i,1,n) prod=1ll*prod*(cnt[i]-(n-i))%mo;
fo(i,1,n*4) tag[i]=1;
rev2=ksm(2,mo-2);
fd(x,n,1)
{
d[0]=0;
for(int p=fst[x];p;p=nxt[p]) d[++d[0]]=b[p];
fac=ll(cnt[x]-1-(n-x))*ksm(cnt[x]-(n-x),mo-2)%mo;
fo(i,1,d[0])
{
tmp=get(1,1,n,d[i],n);
ans=(ans+1ll*tmp.sum*rev2)%mo;
tmp=get(1,1,n,1,d[i]);
ans=(ans+1ll*tmp.cnt*prod-1ll*tmp.sum*rev2)%mo;
}
ans=(ans+1ll*d[0]*(d[0]-1)/2%mo*prod%mo*rev2)%mo;
fo(i,1,d[0]) change(1,1,n,d[i],prod);
tag[1]=1ll*tag[1]*fac%mo;
dw(1,1==n);
}
if (ans<0) ans+=mo;
printf("%d",ans);
}