对于序列A,它的逆序对数定义为满足i<j,且Ai>Aj的数对(i,j)的个数。给1到n的一个排列,按照某种顺序依次删除m个元素,你的任务是在每次删除一个元素之前统计整个序列的逆序对数。
N<=100000 M<=50000
题解:树状数组套线段树。
树状数组中维护的是逆序对的前缀和。
线段树中维护的是权值在树状数组当前点控制范围内区间中点的个数。
把数据离线,倒着搞,原本的删除操作就变成了插入操作。插入一个点那么这个新增节点的逆序对数就是数列中他前面的点中值比他大的(用前面的总点数-前面比他小的个数)+数列中他后面的节点中值比他小的。(统计的时候利用树状数组,就可以把比他小的所有权值通过lowbit的方式全部统计到)
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #define N 10000003 #define M 100003 #define LL long long using namespace std; int n,m,a[M],pos[M],pd[M],b[M],sum[N],c[M]; int rs[N],ls[N],tot,p,s[M],sz,root[M]; LL ans[M]; int lowbit(int x) { return (-x)&x; } int query(int now,int l,int r,int ll,int rr) { //if (delta[now]) pushdown(now,l,r); if (!now) return 0; if (l>=ll&&r<=rr) return sum[now]; int mid=(l+r)/2; int ans=0; if (ll<=mid) ans+=query(ls[now],l,mid,ll,rr); if (rr>mid) ans+=query(rs[now],mid+1,r,ll,rr); return ans; } void update(int x) { sum[x]=sum[ls[x]]+sum[rs[x]]; } void pointchange(int &k,int l,int r,int x) { if (!k) k=++sz; if (l==r) { sum[k]++; return ; } int mid=(l+r)/2; if (x<=mid) pointchange(ls[k],l,mid,x); else pointchange(rs[k],mid+1,r,x); update(k); } int add(int x) { int ans=0; for (int i=x;i;i-=lowbit(i)) ans+=s[i]; return ans; } int solve(int x,int l,int r) { int ans=0; for (int i=x;i;i-=lowbit(i)) ans+=query(root[i],1,n,l,r); return ans; } void insert(int x) { int t=x; int q=add(pos[t]); int a=0; int b=0; a=solve(x,1,pos[t]-1); if (pos[t]+1<=n) b=solve(x,pos[t]+1,n); int k=(q-a)+b; for (int i=x;i<=n;i+=lowbit(i)) { pointchange(root[i],1,n,pos[t]); c[i]+=k; } } LL sum1(int x) { LL ans=0; for (int i=x;i;i-=lowbit(i)) ans+=(LL)c[i]; return ans; } void change(int x) { for (int i=x;i<=n;i+=lowbit(i)) s[i]++; } int main() { scanf("%d%d",&n,&m); for (int i=1;i<=n;i++) scanf("%d",&a[i]),pos[a[i]]=i; for (int i=1;i<=m;i++) scanf("%d",&b[i]),pd[b[i]]=1; int maxn=0; for (int i=n;i>=1;i--) if (!pd[a[i]]) { insert(a[i]); change(pos[a[i]]); maxn=max(maxn,a[i]); } for (int i=m;i>=1;i--) { maxn=max(maxn,b[i]); insert(b[i]); change(pos[b[i]]); ans[i]=(LL)sum1(maxn); } for (int i=1;i<=m;i++) printf("%I64d\n",ans[i]); }