题面:
对于序列A,它的逆序对数定义为满足 i<j ,且 Ai>Aj 的数对 (i,j) 的个数。给1到n的一个排列,按照某种顺序依次删除m个元素,你的任务是在每次删除一个元素之前统计整个序列的逆序对数。
对于已经存在的序列进行删除不是很好操作,不妨先倒过来想,假设我们将这些数空缺,然后倒序加入序列中是否可行,很明显可以,那么剩下的问题就是对于当前加入的书产生了多少个逆序对,对于新添入一个数来说,产生的逆序对一定是当前序列中1~x比它大的和x~n比它小的个数和,很明显这个问题用主席树可以解决,由于下标的顺序问题,不妨用树状数组套一下,建立下标树状数组,这样问题就迎刃而解了。
如果对于每个操作都开logn个结点,那么理论上的空间为 m∗log2(n) ,大约 2∗107 ,此时已经爆了内存。。。但是考虑到我们不需要回溯历史版本,所以稍微处理下,实际结点不需要那么多。
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
//#pragma comment(linker, "/STACK:1024000000,1024000000");
using namespace std;
#define INF 0x3f3f3f3f
#define maxn 100005
int n,m,cnt;
long long sum;
long long ans[maxn];
bool vis[maxn];
int b[maxn];
int a[maxn],indx[maxn];
int root[maxn];
void init()
{
memset(root,0,sizeof root);
memset(vis,false,sizeof vis);
cnt=0;
sum=0;
}
struct node
{
int l,r;
int sum;
node()
{
l=r=sum=0;
}
} t[90*maxn];
void update(int &rt,int pre,int pos,int l,int r)
{
if(!rt) t[rt=++cnt]=t[pre];
t[rt].sum++;
if(l==r&&r==pos)
{
t[rt].sum=1;
return ;
}
int mid=l+r>>1;
if(pos<=mid) update(t[rt].l,t[pre].l,pos,l,mid);
else update(t[rt].r,t[pre].r,pos,mid+1,r);
}
int query1(int rt,int pos,int l,int r)
{
if(l==r&&r==pos)
{
return 0;
}
int mid=l+r>>1;
if(pos<=mid) return query1(t[rt].l,pos,l,mid)+t[t[rt].r].sum;
else return query1(t[rt].r,pos,mid+1,r);
}
int query2(int rt,int pos,int l,int r)
{
if(l==r&&r==pos)
{
return 0;
}
int mid=l+r>>1;
if(pos<=mid) return query2(t[rt].l,pos,l,mid);
else return query2(t[rt].r,pos,mid+1,r)+t[t[rt].l].sum;
}
int lowbit(int x)
{
return x&-x;
}
void add(int x,int pos)
{
for(; x<=n; x+=lowbit(x)) update(root[x],root[x],pos,1,n);
}
long long ask(int x,int pos)
{
long long temp=0;
for(int i=x; i>0; i-=lowbit(i)) temp+=query1(root[i],pos,1,n);
for(int i=n; i>0; i-=lowbit(i)) temp+=query2(root[i],pos,1,n);
for(int i=x; i>0; i-=lowbit(i)) temp-=query2(root[i],pos,1,n);
return temp;
}
int main()
{
scanf("%d%d",&n,&m);
init();
for(int i=1; i<=n; i++)
{
scanf("%d",&a[i]);
indx[a[i]]=i;
}
for(int i=0; iscanf("%d",&b[m-i-1]);
vis[b[m-i-1]]=1;
}
for(int i=1; i<=n; i++)
{
if(!vis[a[i]])
{
add(i,a[i]);
sum+=ask(i,a[i]);
}
}
for(int i=0; ifor(int i=m-1; i>=0; i--)
{
printf("%lld\n",ans[i]);
}
return 0;
}