题目描述
对于序列 a,它的逆序对数定义为集合
{(i,j)| i < j ^ ai > aj }
中的元素个数。
现在给出 1∼n 的一个排列,按照某种顺序依次删除 m 个元素,你的任务是在每次删除一个元素之前统计整个序列的逆序对数。
输入格式
第一行包含两个整数 n 和 m,即初始元素的个数和删除的元素个数。
以下 n 行,每行包含一个 1∼n 之间的正整数,即初始排列。
接下来 m 行,每行一个正整数,依次为每次删除的元素。
输出格式
输出包含 m 行,依次为删除每个元素之前,逆序对的个数。
一道树状数组套权值线段树的题;但是好像比较卡常数;
这道题如果会用权值线段树求逆序对,这道题的思路大致就是那样;只不过权值线段树换成树套树;
代码:
#include
#define LL long long
#define pa pair
#define ls k<<1
#define rs k<<1|1
#define inf 0x3f3f3f3f
using namespace std;
const int N=200100;
const int M=10000000;
const LL mod=100000000;
int n,m,tot,L[N*400],R[N*400],sum[N*400],tr[N],po[N],a[N],vl[N],vr[N];
int cntl,cntr;
LL ans;
int lowbit(int p){ return p&(-p); }
int update(int pre,int l,int r,int pos,int s){
int rt=++tot;
L[rt]=L[pre],R[rt]=R[pre],sum[rt]=sum[pre]+s;
int d=(l+r)>>1;
if(l<r){
if(pos<=d) L[rt]=update(L[pre],l,d,pos,s);
else R[rt]=update(R[pre],d+1,r,pos,s);
}
return rt;
}
void add(int p,int q){
int x=a[p];
while(p<=n){
tr[p]=update(tr[p],1,n,x,q);
p+=lowbit(p);
}
}
void solve(int l,int r){
cntl=0,cntr=0;
while(l){
vl[++cntl]=tr[l];
l-=lowbit(l);
}
while(r){
vr[++cntr]=tr[r];
r-=lowbit(r);
}
}
int query(int l,int r,int ll,int rr){
if(l>=ll&&r<=rr){
int s=0;
for(int i=1;i<=cntr;i++) s+=sum[vr[i]];
for(int i=1;i<=cntl;i++) s-=sum[vl[i]];
return s;
}
int ans=0;
int d=(l+r)>>1;
if(l<r){
vector<int>v1,v2;
int ok=0;
if(ll<=d){
ok=1;
for(int i=1;i<=cntr;i++) v1.push_back(vr[i]),vr[i]=L[vr[i]];
for(int i=1;i<=cntl;i++) v2.push_back(vl[i]),vl[i]=L[vl[i]];
ans+=query(l,d,ll,rr);
}
if(rr>d){
if(ok){
for(int i=1;i<=cntr;i++) vr[i]=R[v1[i-1]];
for(int i=1;i<=cntl;i++) vl[i]=R[v2[i-1]];
}
else{
for(int i=1;i<=cntr;i++) vr[i]=R[vr[i]];
for(int i=1;i<=cntl;i++) vl[i]=R[vl[i]];
}
v1.clear(),v2.clear();
ans+=query(d+1,r,ll,rr);
}
}
return ans;
}
int main(){
// ios::sync_with_stdio(false);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&a[i]),po[a[i]]=i;
for(int i=1;i<=n;i++) add(i,1);
for(int i=1;i<=n;i++){
solve(0,i-1);
ans+=(LL)query(1,n,a[i]+1,n);
}
for(int i=1;i<=m;i++){
int x;scanf("%d",&x);
printf("%lld\n",ans);
int pos=po[x];
solve(0,pos);
ans-=(LL)query(1,n,x+1,n);
solve(pos,n);
ans-=(LL)query(1,n,1,x-1);
add(pos,-1);
}
return 0;
}