比较可惜昨天比赛的时候时间不够了,在比赛结束之后五分钟找出了bug提交通过了。然并软;
首先这题说b数组的后一项要么等于前一项,要么等于前一项加一,而且如果a[i] == a[j] ,那么b[i] == b[j],所以如果a[i] == a[j],b[i]到b[j]这个区间的值都是一样的,可以看做一个整体;
那么这题要求的不就是2^(区间个数 - 1)吗;
刚看这题就觉得区间合并用并查集,但是当时思路不够清晰后来用了线段树ac掉了,今天就把两种方法的代码都贴上;
- 线段树解法
1102E - 22 GNU C++11 Happy New Year! 265 ms 9428 KB #include "bits/stdc++.h" using namespace std; typedef long long LL; const int INF = 0x3f3f3f3f; const int MOD = 998244353; //这里的tree其实就是一个懒标记 int tree[800005]; map<int, int> mp; int n, m, L, R, cnt; //查询包含q的区间前端 int queryHead(int l, int r, int id, int q) { if (tree[id] != 0) { return tree[id]; } int mid = l + r >> 1; if (q <= mid) { return queryHead(l, mid, id << 1, q); } else { return queryHead(mid + 1, r, id << 1 | 1, q); } } //把区间[L, R]的值修改为L; void update(int l, int r, int id) { if (l >= L && r <= R) { tree[id] = L; return; } int mid = l + r >> 1; if (L <= mid) { update(l, mid, id << 1); } if (R > mid) { update(mid + 1, r, id << 1 | 1); } } // 查询线段树中包含q的节点的区间末端 int queryTail(int l, int r, int id, int q) { if (tree[id] != 0) { return r; } int mid = l + r >> 1; if (q <= mid) { return queryTail(l, mid, id << 1, q); } else { return queryTail(mid + 1, r, id << 1 | 1, q); } } // 快速幂 int quick_pow(int n, int m) { int ans = 1; while (m) { if (m & 1) { ans = 1LL * ans * n % MOD; } n = 1LL * n * n % MOD; m >>= 1; } return ans; } int main() { scanf("%d", &n); for (int i = 1; i <= n; i++) { scanf("%d", &m); L = mp.count(m) ? queryHead(1, n, 1, mp[m]) : i; R = i; update(1, n, 1); mp[m] = i; } int mx = -1; for (int i = 1; i <= n; i = queryTail(1, n, 1, i) + 1) { int k = queryHead(1, n, 1, i); // 因为这题区间合并,这里的k得到的不是合并后的区间末端,只是线段树中的区间末端;所以要比较是否和上一个线段树区间属于同一区间 if (k != mx) { mx = k; cnt++; } } printf("%d\n", quick_pow(2, cnt - 1)); return 0; }
- 并查集解法
1102E - 22 GNU C++11 Happy New Year! 171 ms 7100 KB #include "bits/stdc++.h" using namespace std; typedef long long LL; const int INF = 0x3f3f3f3f; const int MOD = 998244353; int pre[200005], cnt; map<int, int> mp; int find(int id) { if (pre[id] == 0) { return id; } return pre[id] = find(pre[id]); } int quick_pow(int n, int m) { int ans = 1; while (m) { if (m & 1) { ans = 1LL * ans * n % MOD; } n = 1LL * n * n % MOD; m >>= 1; } return ans; } int main() { int n, m; scanf("%d", &n); for (int i = 1; i <= n; i++) { scanf("%d", &m); int head = mp.count(m) ? find(mp[m]) : i; int tail = i; while (true) { int x = find(tail); if (x == head) { break; } pre[x] = head; tail = x - 1; } mp[m] = i; } for (int i = n; i > 0; i = find(i) - 1) { cnt++; } printf("%d\n", quick_pow(2, cnt - 1)); return 0; }