问题可以转化为:统计有多少个区间满足 \(\lfloor S / 2 \rfloor \ge M\),其中 \(S\) 为区间内数字的总和,\(M\) 为区间最大值。等价地,统计有多少个区间的总和大于等于最大值的两倍(若区间总和为奇数,需要先减 1 再比较)。
启发式分治:
首先按最大值位置分治,遍历长度较短的一边,枚举它为一个端点,另一边二分算贡献即可。
时间复杂度约为 \(O(n \log^2 n)\)。
#pragma comment(linker, "/STACK:102400000,102400000")
#include <bits/stdc++.h>
// Shorthand for the two members of std::pair.
#define fi first
#define se second
// Replaces std::endl with a plain newline character, so streams are not flushed on every line.
#define endl '\n'
// Square of x. Fully parenthesized so it composes inside larger expressions
// (e.g. `a / o2(b)`); note that x is still evaluated twice.
#define o2(x) ((x)*(x))
#define BASE_MAX 30
#define mk make_pair
#define eb emplace_back
// Expands to the begin/end iterator pair of a container.
#define all(x) (x).begin(), (x).end()
// memset over the whole object a (a must be an array, not a pointer, for sizeof to cover it).
#define clr(a, b) memset((a),(b),sizeof((a)))
// Switches iostreams to unsynchronized mode and unties cin from cout.
#define iis std::ios::sync_with_stdio(false); cin.tie(0)
// Sorts x and removes duplicate elements in place.
#define my_unique(x) sort(all(x)),x.erase(unique(all(x)),x.end())
using namespace std;
#pragma optimize("-O3")
typedef long long LL;
typedef pair<int, int> pii;
// Fast reader: parses the next decimal integer from stdin.
// Every non-digit character in front of the number is skipped; a '-' seen while
// skipping makes the result negative. Returns 0 if the input ends before any
// digit is found (instead of spinning forever on EOF).
inline LL read() {
    LL x = 0;
    int f = 0;
    // getchar() returns an int; keeping it as int lets EOF stay distinguishable.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        f |= (ch == '-');
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return f ? -x : x;
}
// Prints x in decimal followed by a newline, using putchar only.
inline void write(LL x) {
    if (x == 0) {
        putchar('0');
        putchar('\n');
        return;
    }
    // Work on the unsigned magnitude: negating the most negative value as a
    // signed number would overflow, unsigned wrap-around is well defined.
    unsigned long long mag = static_cast<unsigned long long>(x);
    if (x < 0) {
        putchar('-');
        mag = 0ULL - mag;
    }
    char s[23];  // a 64-bit magnitude has at most 20 decimal digits
    int l = 0;
    while (mag != 0) {
        s[l++] = static_cast<char>('0' + mag % 10);
        mag /= 10;
    }
    // Digits were produced least-significant first; emit them in reverse.
    while (l) putchar(s[--l]);
    putchar('\n');
}
// Returns the value of the lowest set bit of x (0 when x is 0).
int lowbit(int x) {
    const int negated = -x;
    return negated & x;
}
// Returns the larger of two values of the same type (a2 when they compare equal).
template<class T>
T big(const T &a1, const T &a2) { return a1 > a2 ? a1 : a2; }
// Variadic overload: largest of three or more values, folded from the right.
// The result has the type of the first argument.
template<class T, class... R>
T big(const T &f, const R &...r) { return big(f, big(r...)); }
// Returns the smaller of two values of the same type (a2 when they compare equal).
template<class T>
T sml(const T &a1, const T &a2) { return a1 < a2 ? a1 : a2; }
// Variadic overload: smallest of three or more values, folded from the right.
// The result has the type of the first argument.
template<class T, class... R>
T sml(const T &f, const R &...r) { return sml(f, sml(r...)); }
// Base case of the debug printer: ends the current line on stderr.
void debug_out() { std::cerr << '\n'; }
// Writes every argument to stderr, each followed by a single space, then ends the line.
template<class T, class... R>
void debug_out(const T &f, const R &...r) {
    std::cerr << f << " ";
    debug_out(r...);
}
// debug(a, b) prints "[a, b]: " followed by the values of a and b on stderr.
// The expansion already ends with ';', so call sites may omit it.
#define debug(...) std::cerr << "[" << #__VA_ARGS__ << "]: ", debug_out(__VA_ARGS__);
// print(x) forwards to write(x). The expansion already ends with ';'.
#define print(x) write(x);
// Unsigned 64-bit shorthand.
typedef unsigned long long uLL;

// Large sentinel values; every byte of each constant is 0x3f.
const int INF = 0x3f3f3f3f;
const LL INFLL = 0x3f3f3f3f3f3f3f3fLL;

// Moduli and bases that the visible code never references
// (presumably leftovers of a hashing template -- TODO confirm before removing).
const int HMOD[] = {1000000009, 1004535809};
const LL BASE[] = {1572872831LL, 1971536491LL};
const int mod = 998244353;
const int MOD = 1000000007;

// Capacity of the global per-element arrays (indices 0 .. MXN-1).
const int MXN = 300007;
// Number of elements in the current test case.
int n;
// Not referenced by the visible code.
int q;
// Input values, stored 1-indexed.
LL ar[MXN];
// Prefix sums of ar; sum[0] stays 0.
LL sum[MXN];
// Sparse table: dp[i][j] is the maximum of ar[i .. i + 2^j - 1], pos[i][j] its index.
LL dp[MXN][20];
LL pos[MXN][20];
// Running answer for the current test case.
LL ans;
// Scratch value used inside solve().
LL ret;
// Builds the sparse table over ar[1..n].
// After the call, dp[i][j] holds the maximum of ar[i .. i + 2^j - 1] and
// pos[i][j] the index where it occurs (the leftmost one on ties).
// Level 0 (dp[i][0] = ar[i], pos[i][0] = i) must already be filled by the caller.
void init() {
    for (int j = 1; j < 20; ++j) {
        const int half = 1 << (j - 1);
        if (half > n) break;  // no window built from halves of this size fits
        for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
            // Merge the two half-windows; ">=" keeps the left one on ties.
            if (dp[i][j-1] >= dp[i + half][j-1]) {
                dp[i][j] = dp[i][j-1];
                pos[i][j] = pos[i][j-1];
            } else {
                dp[i][j] = dp[i + half][j-1];
                pos[i][j] = pos[i + half][j-1];
            }
        }
    }
}
// Returns the index of the maximum of ar[l..r] (the leftmost one on ties),
// answered from the sparse table built by init(). Requires 1 <= l <= r <= n.
inline int query(int l, int r) {
    const int len = r - l + 1;
    // floor(log2(len)) with integer arithmetic: log(len) / log(2) in floating
    // point may land just below an integer at exact powers of two.
    int k = 0;
    while ((2 << k) <= len) ++k;
    // The two windows of length 2^k starting at l and ending at r cover [l, r].
    const int tail = r - (1 << k) + 1;
    if (dp[l][k] >= dp[tail][k]) return pos[l][k];
    return pos[tail][k];
}
// Adds to `ans` the number of index pairs (a, b) with l <= a < b <= r for which
// the sum of ar[a..b], rounded down to an even number, is at least twice the
// maximum of ar[a..b].
//
// Divide and conquer on the position `mid` of the range maximum: every interval
// that contains `mid` is counted here, the rest lies entirely left or right of
// it. The shorter side (together with `mid`) is enumerated as one endpoint and
// the matching endpoints on the other side are found by binary search over the
// prefix sums, which relies on the rounded sum growing as the interval grows.
//
// Only the smaller part is handled by a recursive call; the larger part is
// handled by the loop, so the recursion depth stays logarithmic even when the
// maximum always sits at one end of the range.
void solve(int l, int r) {
    while (l < r) {
        if (l + 1 == r) {
            // Two elements: the condition holds exactly when they are equal.
            ans += (ar[l] == ar[r]);
            return;
        }
        const int mid = query(l, r);
        // debug(mid)
        const bool left_is_shorter = r - mid > mid - l;
        if (left_is_shorter) {
            // i is the left endpoint; search the right endpoint.
            for (int i = l; i <= mid; ++i) {
                // The right endpoint is >= mid, and > mid when i == mid (length >= 2).
                int L = mid + 1, R = r, res = mid;
                if (i != mid) L = mid, res = mid - 1;
                // res becomes the last right endpoint that still fails the condition.
                while (L <= R) {
                    const int M = (L + R) >> 1;
                    ret = sum[M] - sum[i-1];
                    if (ret & 1) --ret;
                    if (ret < 2 * ar[mid]) res = M, L = M + 1;
                    else R = M - 1;
                }
                ans += r - res;
                // debug(i, res)
            }
        } else {
            // i is the right endpoint; search the left endpoint.
            for (int i = mid; i <= r; ++i) {
                // The left endpoint is <= mid, and < mid when i == mid (length >= 2).
                int L = l, R = mid - 1, res = mid;
                if (i != mid) R = mid, res = mid + 1;
                // res becomes the first left endpoint that fails the condition.
                while (L <= R) {
                    const int M = (L + R) >> 1;
                    ret = sum[i] - sum[M-1];
                    if (ret & 1) --ret;
                    if (ret < 2 * ar[mid]) res = M, R = M - 1;
                    else L = M + 1;
                }
                ans += res - l;
                // debug(i, res)
            }
        }
        // Recurse into the smaller part, continue the loop on the larger one.
        if (left_is_shorter) {
            solve(l, mid - 1);
            l = mid + 1;
        } else {
            solve(mid + 1, r);
            r = mid - 1;
        }
    }
}
// Reads the number of test cases, then for each one reads n values, builds the
// prefix sums and the sparse table, and prints the count computed by solve().
int main() {
#ifndef ONLINE_JUDGE
    // Local runs take their input from a fixed file. If it cannot be opened,
    // stdin is unusable afterwards, so stop instead of parsing a dead stream.
    if (freopen("/home/cwolf9/CLionProjects/ccc/in.txt", "r", stdin) == nullptr) {
        fprintf(stderr, "cannot open local input file\n");
        return 1;
    }
    //freopen("/home/cwolf9/CLionProjects/ccc/out.txt", "w", stdout);
#endif
    int tim = read();
    while (tim--) {
        n = read();
        // debug("***")
        ans = 0;
        for (int i = 1; i <= n; ++i) {
            ar[i] = read();
            sum[i] = sum[i-1] + ar[i];
            // Level 0 of the sparse table: each element is its own maximum.
            dp[i][0] = ar[i];
            pos[i][0] = i;
        }
        init();
        solve(1, n);
        print(ans)
        // debug(ans)
    }
#ifndef ONLINE_JUDGE
    // clock() counts ticks; convert through CLOCKS_PER_SEC before labelling the value as milliseconds.
    cout << "time cost:" << clock() * 1000.0 / CLOCKS_PER_SEC << "ms" << endl;
#endif
    return 0;
}
https://codeforces.com/blog/entry/44351