学习了博主:MyZee , shengweison 的文章
线段树或树状数组求逆序数
假设给你一个序列 6 1 2 7 3 4 8 5, 首先我们先手算逆序数, 设逆序数为 N;
6的前面没有比他大的数 N +=0
1的前面有一个比他大的数 N+=1
2的前面有一个比他大的数 N+=1
7的前面没有比他大的数 N+=0
... 最后得到 N = 0 + 1 + 1 + 0 + 2 + 2 + 0 + 3 = 9
其实我们可用用线段树,或者树状数组模拟这个过程。 又因为线段树和树状数组的效率较高,所以可行
假设我们将 序列 6 1 2 7 3 4 8 5 存入数组num【】 中, num【1】=6 , num【2】=1...
那么每次,我们可以将 num【i】 插入到 线段树或者树状数组中,并赋值为 1,
我们求和sum,sum等于线段树中 1 到 num【i】的和 , 那么这个 sum 表示的值就是当前比num【i】小的数量(包括它本身);
而当前一共有 i 个数 , 所以 当前 比num【i】大的数量就是 i - sum;
所以 我们统计所有的 i - sum , 它们的和就是逆序数。 模拟了上面手算的过程。
【线段树的关键代码】
<pre name="code" class="cpp">int count=0; for(int i=1;i<=n;i++){ Insert(1,num[i],num[i],1); //插入 num[i],并赋值1 count+=(i-(Query(1,1,num[i]))); }
【树状数组的关键代码】
long long ans=0; for(int i=1;i<=n;i++){ add(N[i].id); ans+=(i-Sum(N[i].id)); }
当然,这里查询的数 的 id 都是默认从 1 到 N 的,如果 题目要求输入的数超过这个范围,就需要用到离散化,
这个在后面的题目会介绍到。
【这里先给一个求1~n的逆序数】
//线段树 求逆序数 #include <iostream> #include <cstdio> #include <cstring> #define L(a) a<<1 #define R(a) (a<<1)|1 const int maxn = 51000; int ans[maxn]; struct node{ int num,l,r; }tree[maxn<<2]; int n; void Build(int m,int l, int r){ tree[m].l=l; tree[m].r=r; if(tree[m].l==tree[m].r){ tree[m].num=0; return ; } int mid = (tree[m].l+tree[m].r)>>1; Build(L(m),l,mid); Build(R(m),mid+1,r); //并不要回溯, 建立空树 } void Insert(int m,int l,int r,int x){ if(tree[m].l==l&&tree[m].r==r){ tree[m].num+=x; return ; } int mid = (tree[m].l+tree[m].r)>>1; if(r<=mid) Insert(L(m),l,r,x); else if(l>mid) Insert(R(m),l,r,x); else{ Insert(L(m),l,mid,x); Insert(R(m),mid+1,r,x); } tree[m].num=tree[L(m)].num+tree[R(m)].num; } int Query(int m,int l,int r){ if(tree[m].l==l&&tree[m].r==r) return tree[m].num; int mid = (tree[m].l+tree[m].r)>>1; if(r<=mid) return Query(L(m),l,r); if(l>mid) return Query(R(m),l,r); return Query(L(m),l,mid)+Query(R(m),mid+1,r); } int main(){ int a,n,i,t; scanf("%d",&t); while(t--){ int k=0; scanf("%d",&n); memset(tree,0,sizeof(tree)); Build(1,1,n); for(int i=1;i<=n;i++) { scanf("%d",&ans[i]); } for(int i=1;i<=n;i++){ Insert(1,ans[i],ans[i],1);// 每个位置插入1 k+=(i - Query(1,1,ans[i])); } printf("%d\n",k); } return 0; }
#include <iostream> #include <cstdio> #include <cstring> #define L(a) a<<1 #define R(a) a<<1|1 using namespace std; int n; const int maxn = 5005; int num[maxn]; struct node{ int l,r,sum; }tree[maxn<<2]; void Build(int m,int l,int r){ tree[m].l=l; tree[m].r=r; if(tree[m].l==tree[m].r){ //如果当前节点的左右节点相同,即叶子节点 tree[m].sum=0; return ; } int mid = (tree[m].l+tree[m].r)>>1; Build(L(m),l,mid); Build(R(m),mid+1,r); } void Insert(int m,int l,int r,int x){ if(tree[m].l==l&&tree[m].r==r){ tree[m].sum+=x; return ; } int mid = (tree[m].l+tree[m].r)>>1; if(mid>=r) //这里是大于等于 Insert(L(m),l,r,x); else if(mid<l) Insert(R(m),l,r,x); else{ Insert(L(m),l,mid,x); Insert(R(m),mid+1,r,x); } tree[m].sum=tree[L(m)].sum+tree[R(m)].sum; } int Query(int m,int l,int r){ if(tree[m].l==l&&tree[m].r==r){ return tree[m].sum; } int mid = (tree[m].l+tree[m].r)>>1; //这里和 Insert 一样 if(mid>=r) return Query(L(m),l,r); if(mid<l) return Query(R(m),l,r); return Query(L(m),l,mid)+Query(R(m),mid+1,r); } int main(){ while(scanf("%d",&n)!=EOF){ memset(tree,0,sizeof(tree)); Build(1,1,n); for(int i=1;i<=n;i++){ scanf("%d",&num[i]); num[i]++; } int count=0; for(int i=1;i<=n;i++){ Insert(1,num[i],num[i],1); count+=(i-(Query(1,1,num[i]))); } int ans = count; for(int i=1;i<=n;i++){ num[i]--; count = count - num[i]*2 + n -1; if(count<ans) ans = count; } printf("%d\n",ans); } return 0; }
这个需要用到离散化,
建立一个结构体包含val和id, val就是输入的数,id表示输入的顺序。然后按照val从小到大排序,如果val相等,那么就按照id排序。
如果没有逆序的话,肯定id是跟i(表示拍好后的顺序)一直一样的,如果有逆序数,那么有的i和id是不一样的。所以,利用树状数组的特性,我们可以简单的算出逆序数的个数。
如果还是不明白的话举个例子。(输入4个数)
输入:9 -1 18 5
输出 3.
输入之后对应的结构体就会变成这样
val:9 -1 18 5
id: 1 2 3 4
排好序之后就变成了
val : -1 5 9 18
id: 2 4 1 3
2 4 1 3 的逆序数 也是3
之后再利用树状数组的特性就可以解决问题了;
因为数字可能有重复, 所以添加操作不再单纯的置为1 ,而是 ++;
【源代码】
#include <iostream> #include <cstdio> #include <algorithm> #include <cstring> using namespace std; int n; const int maxn = 1000005; struct node{ int val,id; }N[maxn]; int c[maxn]; int cmp(const node &a,const node& b ){ if(a.val==b.val) return a.id<b.id; return a.val<b.val; } int lowbit(int x){ return x&(-x); } void add(int x){ while(x<=n){ c[x]++; //可能有重复,因为用++ 不用 = 1; x+=lowbit(x); } return; } int Sum(int x){ int ans = 0; while(x>0){ ans+=c[x]; x-=lowbit(x); } return ans; } int main(){ int T; scanf("%d",&T); while(T--){ scanf("%d",&n); for(int i=1;i<=n;i++){ scanf("%d",&N[i].val); N[i].id=i; } sort(N+1,N+n+1,cmp);//从 1 - n 排序 memset(c,0,sizeof(c)); //不要忘记初始化 long long ans=0; //用 int 会爆掉 for(int i=1;i<=n;i++){ add(N[i].id); ans+=(i-Sum(N[i].id)); } printf("%lld\n",ans); } return 0; }