HZOJ-289: 生日礼物 (Birthday Gift)

2023-07-21 (first version)

#include <stdio.h>
#include <stdlib.h>
/* Run sums after sign compression, 1-based. */
long long arr[100005] = {0};
/* Linked list over runs: l[i] / r[i] are the previous / next run of run i. */
int l[100005] = {0};
int r[100005] = {0};
/* Heap stored column-wise: heap[0][p] is the key of slot p,
 * heap[1][p] is the run index that key belongs to. */
long long heap[2][100005] = {0};
/* st[i] becomes 1 once del(i) has unlinked run i. */
int st[100005] = {0};

/* Unlink run x from the l/r list by pointing its neighbours at each other,
 * and flag it in st[] so stale heap entries for it can be skipped.
 * l[x] and r[x] themselves are left unchanged. */
void del(int x) {
	int prev = l[x];
	int next = r[x];
	st[x] = 1;
	r[prev] = next;
	l[next] = prev;
}
/* Exchange the two long long values pointed to by n and m.
 * If both point to the same object it is left unchanged. */
void swap(long long* n, long long* m) {
	const long long saved = *m;
	*m = *n;
	*n = saved;
}
/* Sift slot i of the 1-based min-heap toward the root.
 * heap[0] holds keys and heap[1] the matching run indices; both rows are
 * swapped together. The entry keeps moving while its key is <= its parent's,
 * so equal keys are swapped as well. n is not used. */
void up_update(long long heap[][100005], int i, int n) {
	(void)n;
	for (int parent = i / 2; i > 1 && heap[0][i] <= heap[0][parent]; parent = i / 2) {
		swap(&heap[0][i], &heap[0][parent]);
		swap(&heap[1][i], &heap[1][parent]);
		i = parent;
	}
}

/* Sift slot i of the 1-based min-heap heap[0][1..n] down toward the leaves,
 * moving heap[1] alongside. On equal keys the child wins (and the right child
 * wins over the left), matching the >= comparisons. */
void down_update(long long heap[][100005], int i, int n) {
	for (;;) {
		int child = 2 * i;
		if (child > n) break;
		int best = heap[0][i] >= heap[0][child] ? child : i;
		if (child + 1 <= n && heap[0][best] >= heap[0][child + 1]) best = child + 1;
		if (best == i) break;
		swap(&heap[0][i], &heap[0][best]);
		swap(&heap[1][i], &heap[1][best]);
		i = best;
	}
}

/* Remove the root of the heap whose size is *p.
 * The last slot is swapped into position 1 (the old root is parked at the
 * old last slot, just past the new end), *p is decremented, and the new root
 * is sifted down. */
void heap_erase(long long heap[][100005], int* p) {
	const int last = (*p)--;
	for (int row = 0; row < 2; row++)
		swap(&heap[row][1], &heap[row][last]);
	down_update(heap, 1, *p);
}


/* Nonzero when a and b have strictly opposite signs (zero matches either).
 * Same truth table as `a * b < 0`, but never forms the product. */
static int signs_differ(long long a, long long b) {
	return (a < 0 && b > 0) || (a > 0 && b < 0);
}

/*
 * HZOJ-289: read n, m and n integers; print the largest total obtainable from
 * at most m disjoint contiguous segments.
 *
 * Consecutive same-sign values are merged into runs arr[1..k]. Every positive
 * run is taken first (sum s, count cnt). While cnt > m, the run with the
 * smallest |value| is popped from the min-heap:
 *   - a non-positive run at either end of the list is dropped;
 *   - any other run costs |value|, and is merged with its two neighbours into
 *     one run that is pushed back into the heap.
 * Indices 0 and k + 1 act as sentinels. arr[0] and arr[k + 1] are never
 * written, so they read as 0.
 */
int main(void) {
	int n, m;
	scanf("%d%d", &n, &m);
	scanf("%lld", &arr[1]);
	int i, k = 1;
	for (i = 1; i < n; i++) {
		long long a;
		scanf("%lld", &a);
		if (signs_differ(a, arr[k])) arr[++k] = a;
		else arr[k] += a;
	}
	long long cnt = 0, s = 0;
	n = k; /* from here on n is the heap size; k keeps the run count */
	for (i = 1; i <= n; i++) {
		if (arr[i] > 0) {
			s += arr[i];
			cnt++;
		}
		heap[0][i] = llabs(arr[i]); /* abs() would convert the long long to int */
		heap[1][i] = i;
		l[i] = i - 1;
		r[i] = i + 1;
	}
	for (i = n / 2; i >= 1; i--) down_update(heap, i, n);
	while (cnt > m) {
		/* Discard heap entries whose run was already merged away. */
		while (st[heap[1][1]]) heap_erase(heap, &n);
		i = (int)heap[1][1];
		heap_erase(heap, &n);
		int left = l[i], right = r[i];
		if ((left > 0 && right <= k) || arr[i] > 0) {
			s -= llabs(arr[i]);
			cnt--;
			arr[i] += arr[left] + arr[right];
			heap[0][++n] = llabs(arr[i]);
			heap[1][n] = i;
			up_update(heap, n, n);
			/* Unlink only real runs. del(k + 1) would read r[k + 1], which is
			 * never set (0), and store it into the last run's r[]. That run
			 * would then satisfy `right <= k` and be merged as if interior. */
			if (left > 0) del(left);
			if (right <= k) del(right);
		}
	}
	printf("%lld", s);
	return 0;
}

2023-10-06 (second version)

#include <stdio.h>
#include <stdlib.h>
#define MAX 100005
/* Comparison operator spliced between two |arr[...]| values by the heap code;
 * `<` makes hea a min-heap ordered by |arr[run]|. */
#define cmp <

/* arr:   run sums after sign compression, 1-based; arr[0] stays 0.
 * hea:   1-based binary heap of run indices.
 * l, r:  previous / next run of each run (linked list).
 * index: set to 1 by del(); heap entries for such runs are stale and skipped. */
int arr[MAX] = { 0 }, hea[MAX], l[MAX], r[MAX], index[MAX] = { 0 };

/* Unlink run i from the l/r list by joining its neighbours, and mark it in
 * index[] so stale heap entries for it are skipped. l[i]/r[i] are kept. */
void del(int i) {
	int before = l[i];
	int after = r[i];
	index[i] = 1;
	r[before] = after;
	l[after] = before;
}

/* Exchange the two ints pointed to by a and b (no-op if they alias). */
void swap(int* a, int* b) {
	const int held = *b;
	*b = *a;
	*a = held;
}

/* Sift heap slot n toward the root: it moves past its parent while
 * |arr[child run]| cmp |arr[parent run]| holds. */
void up_updata(int* hea, int n) {
	for (; n > 1; n /= 2) {
		int parent = n / 2;
		if (!(abs(arr[hea[n]]) cmp abs(arr[hea[parent]]))) break;
		swap(&hea[n], &hea[parent]);
	}
}

/* Sift slot i of the heap ind[1..n] down, ordering slots by
 * |arr[ind[slot]]| with the cmp operator; a child replaces the current pick
 * only when it compares strictly better. */
void down_updata(int* ind, int i, int n) {
	for (int child = 2 * i; child <= n; child = 2 * i) {
		int pick = i;
		if (abs(arr[ind[child]]) cmp abs(arr[ind[pick]])) pick = child;
		if (child + 1 <= n && abs(arr[ind[child + 1]]) cmp abs(arr[ind[pick]])) pick = child + 1;
		if (pick == i) break;
		swap(&ind[i], &ind[pick]);
		i = pick;
	}
}

/* Heapify hea[1..n] in place by sifting down every internal slot, last first.
 * (The parameter holds heap slots, not the global arr of run sums.) */
void build_heap(int* hea, int n) {
	int i = n / 2;
	while (i > 0) {
		down_updata(hea, i, n);
		i--;
	}
}

/* Pop the root of heap hea[1..*n]: move the last slot to the root,
 * shrink *n by one, and sift the new root down. */
void del_hea(int* hea, int* n) {
	int last = *n;
	*n = last - 1;
	hea[1] = hea[last];
	down_updata(hea, 1, *n);
}

/* Append run index i to the heap, growing *n by one, then sift it up. */
void add_hea(int* hea, int i, int* n) {
	*n += 1;
	hea[*n] = i;
	up_updata(hea, *n);
}

/* Nonzero when a and b have strictly opposite signs (zero matches either).
 * Same truth table as `a * b < 0`. The product form can overflow int once a
 * run sum grows large, and signed overflow is undefined behaviour. */
static int opposite_sign(int a, int b) {
	return (a < 0 && b > 0) || (a > 0 && b < 0);
}

/*
 * HZOJ-289: read n, m and n integers; print the largest total obtainable from
 * at most m disjoint contiguous segments.
 *
 * Values are compressed in place into alternating-sign runs arr[1..n]. All
 * positive runs are taken first (sum, cnt). While cnt > m, the run with the
 * smallest |value| is popped:
 *   - a non-positive run at either end of the list is dropped;
 *   - any other run costs |value|, and is merged with both neighbours into
 *     one run that is pushed back into the heap.
 * Indices 0 and n + 1 act as zero-valued sentinels.
 */
int main() {
	int n, m, j = 1;
	scanf("%d %d", &n, &m);
	for (int i = 1; i <= n; i++) scanf("%d", &arr[i]);
	/* Merge consecutive same-sign values into runs; zeros join the current run. */
	for (int i = 2; i <= n; i++) {
		if (opposite_sign(arr[j], arr[i])) arr[++j] = arr[i];
		else arr[j] += arr[i];
	}
	/* Slots past the last run still hold raw input values. arr[j + 1] is read
	 * as the right sentinel when a run at the right end is merged, so it must
	 * be 0 like arr[0]; clear the whole leftover tail. */
	for (int i = j + 1; i <= n; i++) arr[i] = 0;
	n = j;

	int sum = 0, cnt = 0;
	for (int i = 1; i <= n; i++) {
		if (arr[i] > 0) sum += arr[i], cnt++;
		hea[i] = i;
		l[i] = i - 1, r[i] = i + 1;
	}
	int hea_n = n;
	build_heap(hea, hea_n);
	while (cnt > m) {
		/* Drop heap entries for runs already merged away. */
		while (index[hea[1]]) del_hea(hea, &hea_n);
		int ind = hea[1], left = l[ind], right = r[ind];
		del_hea(hea, &hea_n);
		if ((left > 0 && right <= n) || arr[ind] > 0) {
			sum -= abs(arr[ind]);
			cnt--;
			arr[ind] += arr[left] + arr[right];
			add_hea(hea, ind, &hea_n);
			/* Unlink only real runs. del(n + 1) would read r[n + 1], which is
			 * never set (0), and store it into the last run's r[]. That run
			 * would then satisfy `right <= n` and be merged as if interior. */
			if (left > 0) del(left);
			if (right <= n) del(right);
		}
	}
	printf("%d", sum);
	return 0;
}

你可能感兴趣的:(算法题,c语言,算法)