题目
对于一个数组 a 1 , a 2 , a 3 , … , a n a_1,a_2,a_3,\dots,a_n a1,a2,a3,…,an,动态删除其中的点,在每次删除前求出逆序对个数。
本题解使用CDQ分治
简述:
需求区间 [ l , r ] [l,r] [l,r]每个数的答案,将其分为两个子区间 [ l , m i d ] [l,mid] [l,mid]和 [ m i d + 1 , r ] [mid+1,r] [mid+1,r],先递归处理子区间内部,再计算两个区间之间的相互贡献
首先,发现每次删除一个数,那么只要用当前答案减去这个数的贡献即可
问题就转移到了求每个数的逆序对个数
对于每个数,求逆序对可分为两个部分:
显而易见地,三维偏序
在本题中,通过排序,位置可以很轻松地压掉
在CDQ分治中,可以消去值,即对值排序
注意,此时区间 [ l , m i d ] [l,mid] [l,mid]和 [ m i d + 1 , r ] [mid+1,r] [mid+1,r]内部的位置已经乱序,但CDQ分治是计算左区间与右区间之间的相互影响,不会计算区间内部,因此对于两区间来说,左区间的每一个数的位置都会比右区间的任意一个位置小,所以并没有关系
这时候,只需要计算删除时间的影响了
右区间对左区间的影响,其实具有单调性
因为若 a i < b i a_i
所以我们用树状数组统计,树状数组下标为删除时间,树状数组的值存储个数
那么当前节点的贡献就是删除时间(即下标)大于当前删除时间的总数
也就是树状数组的 s u m ( d e l i , m + 1 ) sum(del_i,m+1) sum(deli,m+1)
对于不删除的点,记其删除时间为 m + 1 m+1 m+1,所以求和求到 m + 1 m+1 m+1
记得开longlong
#include
using namespace std;
typedef long long LL;
#define lowbit(x) x & (-x)//与函数定义无区别
const int N = 100005;
int n, m, tr[N], pl[N];//pl[i]指值为i的数在数组中的位置
LL ans = 0;
struct node {
int a, b;
LL sum;
} p[N];//a 代表数组的值,b 代表删除时间
bool cmp1(node T1, node T2) { return T1.a < T2.a; }
bool cmp2(node T1, node T2) { return T1.b < T2.b; }
inline void update(int x, int y) { for (; x <= n + 1; x += lowbit(x)) tr[x] += y; }
inline int getsum(int x) {
int sum = 0;
for (; x; x -= lowbit(x)) sum += tr[x];
return sum;
}
inline void solve(int l, int r) {//CDQ分治
if (l == r) return;
int mid = l + r >> 1;
solve(l, mid), solve(mid + 1, r);//先解决子问题内部的事
sort(p + l, p + mid + 1, cmp1), sort(p + mid + 1, p + r + 1, cmp1);//分别按值排序
int j = mid + 1;
for (int i = l; i <= mid; i++) {
while (p[i].a > p[j].a && j <= r) update(p[j].b, 1), j++;//使左边大于右边,以左边为基准,计算右边对当前的贡献
p[i].sum += getsum(m + 1) - getsum(p[i].b);//前缀和做差
}
for (int i = mid + 1; i < j; i++) update(p[i].b, -1);//归零
j = mid;
for (int i = r; i >= mid + 1; i--) {
while (p[j].a > p[i].a && j >= l) update(p[j].b, 1), j--;//以右边为基准,计算左边对当前的贡献
p[i].sum += getsum(m + 1) - getsum(p[i].b);
}
for (int i = mid; i > j; i--) update(p[i].b, -1);//归零
}
int main() {
scanf("%d%d", &n, &m);
int x;
for (int i = 1; i <= n; i++) scanf("%d", &p[i].a), pl[p[i].a] = i;
for (int i = 1; i <= m; i++) scanf("%d", &x), p[pl[x]].b = i;
for (int i = 1; i <= n; i++)
if (!p[i].b)
p[i].b = m + 1;
for (int i = 1; i <= n; i++) ans += getsum(n + 1) - getsum(p[i].a), update(p[i].a, 1);//计算原始的值
for (int i = 1; i <= n; i++) update(p[i].a, -1);
solve(1, n);
sort(p + 1, p + n + 1, cmp2);//按删除时间排序,准备输出
for (int i = 1; i <= m; i++) printf("%lld\n", ans), ans -= p[i].sum;
return 0;
}