传送门
题目大意:给出三个1~n的排列,求有多少个数对,在这三个排列中的相对位置相同。
处理出每一个数在三个排列中的位置,就形成了n个三维点对
然后就是一个三维偏序问题了,用cdq分治+bit求解
做完之后发现有一个更厉害的只用bit的方法
考虑容斥,答案=总数-不符合条件的对数
然后又知道不符合条件的点对一定是在两个排列里相对位置相同,在一个中和另外两个不同,那么两两统计三个排列,相当于每一个点对都被统计了2次
那么关于一个点对x,y,若它们在两个排列中
____y__u______x____k
_______x___________y
这样从后向前扫第二个排列,扫到y的时候在第一个排列y的位置+1,然后扫到x的时候查询第一个排列x位置的前缀和,前缀和即为不满足条件的对数
所以就对扫到的每一个数都这样做就行了
cdq分治
#include
#include
#include
#include
#include
using namespace std;
#define N 200005
int n;
struct data{int x,y,z;}p[N],a[N],b[N];
int C[N],ch[N];
long long ans;
int cmpx(data a,data b)
{
return a.xx;
}
int cmpy(data a,data b)
{
return a.yy;
}
void add(int loc,int val)
{
for (int i=loc;i<=n;i+=i&-i)
C[i]+=val;
}
int query(int loc)
{
int ans=0;
for (int i=loc;i>=1;i-=i&-i)
ans+=C[i];
return ans;
}
void cdq(int l,int r)
{
if (l>=r) return;
int mid=(l+r)>>1;
cdq(l,mid);
int acnt=0,bcnt=0;
for (int i=l;i<=mid;++i) a[++acnt]=p[i];
for (int i=mid+1;i<=r;++i) b[++bcnt]=p[i];
sort(a+1,a+acnt+1,cmpy);
sort(b+1,b+bcnt+1,cmpy);
int pa=1,pb=1,tot=0;
while (pb<=bcnt)
{
while (pa<=acnt&&a[pa].y<=b[pb].y)
{
add(a[pa].z,1);
ch[++tot]=a[pa].z;
++pa;
}
ans+=(long long)query(b[pb].z);
++pb;
}
for (int i=1;i<=tot;++i)
add(ch[i],-1);
cdq(mid+1,r);
}
int main()
{
scanf("%d",&n);
for (int i=1;i<=n;++i)
{
int x;scanf("%d",&x);
p[x].x=i;
}
for (int i=1;i<=n;++i)
{
int x;scanf("%d",&x);
p[x].y=i;
}
for (int i=1;i<=n;++i)
{
int x;scanf("%d",&x);
p[x].z=i;
}
sort(p+1,p+n+1,cmpx);
cdq(1,n);
printf("%lld\n",ans);
}
bit
#include
#include
#include
#include
#include
using namespace std;
#define LL long long
#define N 200005
int n;
int a[4][N],loc[N][4],C[N];
LL ans;
void add(int loc,int val)
{
for (int i=loc;i<=n;i+=i&-i)
C[i]+=val;
}
int query(int loc)
{
int ans=0;
for (int i=loc;i>=1;i-=i&-i)
ans+=C[i];
return ans;
}
LL solve(int id,int jd)
{
LL ans=0;
memset(C,0,sizeof(C));
for (int i=n;i>=1;--i)
{
int x=loc[a[jd][i]][id];
ans+=(LL)query(x),add(x,1);
}
return ans;
}
int main()
{
scanf("%d",&n);
for (int i=1;i<=3;++i)
for (int j=1;j<=n;++j) scanf("%d",&a[i][j]),loc[a[i][j]][i]=j;
ans=(LL)n*(n-1);
ans-=solve(1,2)+solve(1,3)+solve(2,3);
printf("%lld\n",ans>>1LL);
}