// Application of a persistent ("functional") segment tree (函数式线段树的应用).
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <queue>
#include <deque>
#include <functional>
#include <algorithm>
#include <vector>
#include <cstring>
#include <stack>
#include <cctype>
#include <utility>
#include <map>
#include <string>
#include <climits>
#include <set>
#include <sstream>
#include <ctime>
#include <bitset>

using std::priority_queue;
using std::vector;
using std::swap;
using std::stack;
using std::sort;
using std::max;
using std::min;
using std::pair;
using std::map;
using std::string;
using std::cin;
using std::cout;
using std::set;
using std::queue;
using std::stringstream;
using std::make_pair;
using std::getline;
using std::greater;
using std::endl;
using std::multimap;
using std::deque;
using std::unique;
using std::lower_bound;
using std::bitset;
using std::upper_bound;
using std::multiset;
// NOTE: `using std::random_shuffle;` was dropped: it is unused here and
// std::random_shuffle was removed from the standard library in C++17.

typedef long long LL;
typedef unsigned long long ULL;
typedef unsigned UN;
typedef pair<int, int> PAIR;
typedef multimap<int, int> MMAP;
typedef LL TY;
typedef long double LF;

// Generic contest-template constants; of these only MAXN is used below.
const int MAXN(2000010);
const int MAXM(50010);
const int MAXE(150010);
const int MAXK(6);
const int HSIZE(13131);
const int SIGMA_SIZE(4);
const int MAXH(20);
const int INFI((INT_MAX - 1) >> 1);
const ULL BASE(31);
const LL LIM(1e13);
const int INV(-10000);
const int MOD(31313);
const double EPS(1e-7);
const LF PI(acos(-1.0));

template<typename T> inline void checkmax(T &a, T b) { if (b > a) a = b; }
template<typename T> inline void checkmin(T &a, T b) { if (b < a) a = b; }
template<typename T> inline T ABS(const T &a) { return a < 0 ? -a : a; }

// Capacity of the 1-based per-position arrays (arr, tab, root).
const int MAXA(100010);

// Node pool of the persistent segment tree over value ranks [1, tn].
//   ls / rs : child node indices
//   sum     : number of inserted elements whose rank lies in the node's range
//   root[i] : version containing arr[1..i]; root[0] is the empty tree
// Node budget: 2*tn-1 nodes for the empty tree plus one root-to-leaf path per
// insertion (at most 18 nodes when tn <= 1e5), which fits in MAXN for n <= 1e5.
int ls[MAXN], rs[MAXN], sum[MAXN], root[MAXA];
int rear;  // index of the next free node in the pool

// Builds an all-zero tree over [l, r]; rt receives the index of its root.
void build(int l, int r, int &rt)
{
    rt = rear++;
    sum[rt] = 0;
    if (l == r) return;
    int m = (l + r) >> 1;
    build(l, m, ls[rt]);
    build(m + 1, r, rs[rt]);
}

// Creates a new version from version prt with one more element of rank val.
// Only the nodes on the root-to-leaf path of val are copied; every other
// subtree is shared with prt. rt receives the new root.
void updata(int l, int r, int val, int prt, int &rt)
{
    rt = rear++;
    sum[rt] = sum[prt] + 1;
    ls[rt] = ls[prt];
    rs[rt] = rs[prt];
    if (l == r) return;
    int m = (l + r) >> 1;
    if (val <= m) updata(l, m, val, ls[prt], ls[rt]);
    else updata(m + 1, r, val, rs[prt], rs[rt]);
}

// Accumulator written by query1 / query2 (callers must reset it first).
LL inv;

// Adds to `inv` the number of elements with rank strictly less than val that
// are present in version rrt but not in version lrt.
void query1(int l, int r, int val, int lrt, int rrt)
{
    if (l == r) return;
    int m = (l + r) >> 1;
    if (val <= m) {
        query1(l, m, val, ls[lrt], ls[rrt]);
    } else {
        inv += sum[ls[rrt]] - sum[ls[lrt]];
        query1(m + 1, r, val, rs[lrt], rs[rrt]);
    }
}

// Adds to `inv` the number of elements with rank strictly greater than val
// that are present in version rrt but not in version lrt.
void query2(int l, int r, int val, int lrt, int rrt)
{
    if (l == r) return;
    int m = (l + r) >> 1;
    if (val <= m) {
        inv += sum[rs[rrt]] - sum[rs[lrt]];
        query2(l, m, val, ls[lrt], ls[rrt]);
    } else {
        query2(m + 1, r, val, rs[lrt], rs[rrt]);
    }
}

// arr[1..n]: the sequence (replaced in place by value ranks); tab: sort buffer.
int arr[MAXA], tab[MAXA];

// Number of ranks < val among positions lo+1..hi (versions root[lo]..root[hi]).
static LL countLess(int tn, int val, int lo, int hi)
{
    inv = 0;
    query1(1, tn, val, root[lo], root[hi]);
    return inv;
}

// Number of ranks > val among positions lo+1..hi (versions root[lo]..root[hi]).
static LL countGreater(int tn, int val, int lo, int hi)
{
    inv = 0;
    query2(1, tn, val, root[lo], root[hi]);
    return inv;
}

// Replaces arr[1..n] by 1-based ranks among its distinct values and returns
// the number of distinct values.
static int compressValues(int n)
{
    for (int i = 1; i <= n; ++i) tab[i - 1] = arr[i];
    sort(tab, tab + n);
    const int tn = int(unique(tab, tab + n) - tab);
    for (int i = 1; i <= n; ++i)
        arr[i] = int(lower_bound(tab, tab + tn, arr[i]) - tab) + 1;
    return tn;
}

// Counts pairs (l, r), 1 <= l < r <= n, for which arr[1..l] followed by
// arr[r..n] has at most K inversions. Uses two pointers: the minimal valid r
// never decreases as l grows. Requires root[0..n] to be built.
static LL countPairs(int n, int tn, LL K)
{
    LL ans = 0;
    LL ret = 0;      // inversions of arr[1..p1] ++ arr[p2..n]
    int p1 = 1;      // prefix is arr[1..p1]
    int p2 = n + 1;  // suffix is arr[p2..n]; n + 1 means empty

    // For l = 1, extend the suffix leftwards while the budget allows,
    // never letting it touch the prefix.
    while (p2 - 1 > p1) {
        const int s = p2 - 1;
        // New inversions: with the single prefix element, and with the
        // smaller elements already in the suffix.
        const LL temp = ret + (arr[p1] > arr[s] ? 1 : 0)
                      + countLess(tn, arr[s], s - 1, n);
        if (temp > K) break;
        ret = temp;
        p2 = s;
    }

    while (p1 < n) {
        ans += n - p2 + 1;  // every r in [p2, n] is valid for l = p1
        ++p1;
        // Empty suffix: p2 never decreases, so no further pairs exist.
        if (p2 > n) break;
        if (p1 >= p2) {
            // The prefix now reaches the suffix start: drop arr[p2] from the suffix.
            ret -= countGreater(tn, arr[p2], 0, p1 - 1);
            ret -= countLess(tn, arr[p2], p2 - 1, n);
            ++p2;
        }
        // Append arr[p1] to the prefix.
        ret += countGreater(tn, arr[p1], 0, p1);
        ret += countLess(tn, arr[p1], p2 - 1, n);
        // Shrink the suffix until the inversion budget holds again.
        while (p2 <= n && ret > K) {
            ret -= countGreater(tn, arr[p2], 0, p1);
            ret -= countLess(tn, arr[p2], p2 - 1, n);
            ++p2;
        }
    }
    return ans;
}

// Reads test cases "n K a1..an" until input ends and prints one count per case.
int main()
{
    int n;
    LL K;
    // "%lld" is the portable long long conversion. "%I64d" is MSVCRT-only:
    // glibc treats 'I' as its locale-digits flag, so it would read an int into K.
    // Comparing with 2 also stops on malformed input instead of looping forever.
    while (scanf("%d%lld", &n, &K) == 2) {
        for (int i = 1; i <= n; ++i) scanf("%d", arr + i);
        if (n <= 0) {  // no pairs; also keeps build() off an empty range
            printf("0\n");
            continue;
        }
        const int tn = compressValues(n);
        rear = 0;
        build(1, tn, root[0]);
        for (int i = 1; i <= n; ++i) updata(1, tn, arr[i], root[i - 1], root[i]);
        printf("%lld\n", countPairs(n, tn, K));
    }
    return 0;
}