POJ 3376 Finding Palindromes

题目大意:

给你N个字符串, 你可以两两连接得到N * N个字符串, 问之中回文串的数量. N个字符串的长度和加起来不超过2000000.

 

简要分析:

无比恶心的题啊...

我们顺次考虑每个字符串放在前面的情况. 假设字符串i放在前面, j放在后面, 那么这个串是回文有两种情况:

1) 若i的长度小于j, 则i是j反串的前缀, 且j反串剩下的后缀是回文串.

2) 若i的长度不小于j, 则j反串是i的前缀, 且i串剩下的后缀是回文串.

于是大致的思路就有了, 把所有串的反串丢到Trie里面, 每个结点额外记录两个值: 从这个点往下走到叶子, 有多少串是回文; 在这个点结束的字符串有多少. 这两个值就分别对应前文的两种情况了.

1) 在Trie中, 若i串在非叶子结点匹配完成, 则把该节点往下走有多少个回文串(即j反串的后缀!!!)累加到答案.

2) 在Trie中, 若在匹配i串时遇上在这个结点结束的字符串, 那么看i剩下的后缀是否是回文串, 若是, 则把在该点结束的字符串数目累加到答案.

在艰难的分析后, 问题转化成求某个串(i串和j反串)有哪些后缀是回文串. 第一反应是后缀数组, 把串和其反串连起来, 中间用奇葩字符隔开, 求一遍后缀数组, 设字符串长度为N, 则i后缀是回文串等价于i后缀与N+2后缀的LCP为N-i+1. 但是由于后缀数组巨大的空间开销和常数, 在这里用不是TLE就是MLE...于是囧了, 看来又是某种生僻算法了.

翻了下Discuss, 看到扩展KMP的字样. 百度一下发现是解决这么一个问题, 给串S和模式串T, 求S的所有后缀与T的LCP, 复杂度O(LenS + LenT). 这...不正是我们想要的吗! 令S为我们想知道哪些后缀是回文的那个串, 模式串T为其反串, 那么就看i后缀与T的LCP是否为LenS - i + 1了.

于是, 这个算法是这样的: 设下标都从0开始, S串已经处理到i后缀, ex[i]记录i后缀与T的LCP, 设i之前匹配的最远的位置是k, 则这个最远的位置p = k + ex[k] - 1. 假设我们手上还有个NX的数组next, next[i]表示T的i后缀与T的最长公共前缀, 下面开始推:

S[k..p] = T[0..p - k].

S[i..p] = T[i - k..p - k].

令next[i - k] = L, 则T[0..L - 1] = T[i - k..i - k + L - 1], S[i..i + L - 1] = T[0..L - 1].

接着我们看i + L - 1与p的关系.

1) 若i + L - 1 < p, 而p之前的位置都是枚举到过的, 所以ex[i]不会超过L, 直接ex[i] = L即可.

2) 否则, 因为p之后都是未探索的地区, 所以要往后匹配. S[i..p] = T[0..p-i], 于是令j = p - i + 1, 从S[i + j]与T[j]开始直接往后模拟匹配, 得出ex[i]值并更新k和p.

至于那个NX的next数组, 其实相当于母串S与T相同, 于是求法也就相同了, 就像KMP中的自身匹配求出pre一样.

呼~终于说完了. 只要知道了思路, 写的时候直接推导就可以了, 没必要记. 现在记不住的式子有两个了, 一个是扩展GCD, 一个是扩展KMP...

PS: 这还不是这题的恶心之处...我对最初代码大概依次有如下改动: 后缀数组改扩展KMP, TLE; string换成char数组, 用指针表示每个字符串, TLE; 去掉几个没必要的memset, TLE; 把Trie树改成邻接表用图来存, TLE...

最后我把某个int数组类型改成char, A了...

 

代码实现:

View Code
  1 #include <cstdio>
2 #include <cstdlib>
3 #include <cstring>
4 #include <algorithm>
5 using namespace std;
6
7 const int MAX_N = 2000000;
8 char buf[MAX_N + 1], tmp[MAX_N + 1];
9 int n, l[MAX_N], sz[MAX_N], ex[MAX_N], next[MAX_N];
10 char s[MAX_N], t[MAX_N];
11 long long ans = 0LL;
12
13 namespace trie {
14 const int MAX_V = MAX_N + 1, MAX_E = 10000000;
15 int ecnt, begin[MAX_V], to[MAX_E], next[MAX_E], end[MAX_V], cnt[MAX_V];
16 char val[MAX_V];
17
18 int node_idx, root;
19
20 void init() {
21 ecnt = 0;
22 memset(begin, -1, sizeof(begin));
23 node_idx = 0;
24 root = node_idx ++;
25 val[root] = -1;
26 }
27
28 void add_edge(int u, int v) {
29 next[ecnt] = begin[u];
30 begin[u] = ecnt;
31 to[ecnt ++] = v;
32 }
33
34 void ins(int sz) {
35 int pos = root;
36 for (int i = 0; i < sz; i ++) {
37 int t = s[i] - 'a';
38 if (ex[i] == sz - i) cnt[pos] ++;
39 bool exi = 0;
40 for (int now = begin[pos]; now != -1; now = next[now])
41 if (val[to[now]] == t) {
42 exi = 1;
43 pos = to[now];
44 }
45 if (!exi) {
46 int v = node_idx ++;
47 add_edge(pos, v);
48 val[v] = t;
49 pos = v;
50 }
51 }
52 end[pos] ++;
53 }
54
55 void go(int sz) {
56 int pos = root;
57 for (int i = 0; i < sz; i ++) {
58 int t = s[i] - 'a';
59 bool exi = 0;
60 for (int now = begin[pos]; now != -1; now = next[now])
61 if (val[to[now]] == t) {
62 exi = 1;
63 pos = to[now];
64 }
65 if (!exi) return;
66 if (end[pos]) {
67 if (i < sz - 1 && ex[i + 1] == sz - i - 1) ans += end[pos];
68 else if (i == sz - 1) ans += end[pos];
69 }
70 }
71 ans += cnt[pos];
72 }
73 }
74
75 void ex_kmp(int len) {
76 //memset(ex, 0, sizeof(int) * len), memset(next, 0, sizeof(int) * len);
77 next[0] = len;
78 next[1] = len - 1;
79 for (int i = 0; i < len - 1; i ++)
80 if (t[i] != t[i + 1]) {
81 next[1] = i;
82 break;
83 }
84 int j, k = 1, p, l;
85 for (int i = 2; i < len; i ++) {
86 p = k + next[k] - 1;
87 l = next[i - k];
88 if (i + l - 1 < p) next[i] = l;
89 else {
90 j = max(0, p + 1 - i);
91 while (i + j < len && t[i + j] == t[j]) j ++;
92 next[i] = j, k = i;
93 }
94 }
95 ex[0] = len;
96 for (int i = 0; i < len; i ++)
97 if (s[i] != t[i]) {
98 ex[0] = i;
99 break;
100 }
101 k = 0;
102 for (int i = 1; i < len; i ++) {
103 p = k + ex[k] - 1;
104 l = next[i - k];
105 if (i + l - 1 < p) ex[i] = l;
106 else {
107 j = max(0, p + 1 - i);
108 while (i + j < len && s[i + j] == t[j]) j ++;
109 ex[i] = j, k = i;
110 }
111 }
112 }
113
114 int main() {
115 //freopen("t.in", "r", stdin);
116 scanf("%d", &n);
117 int tot = 0;
118 for (int i = 0; i < n; i ++) {
119 scanf("%d%s", &sz[i], buf + tot);
120 l[i] = tot;
121 tot += sz[i];
122 }
123 trie::init();
124 for (int it = 0; it < n; it ++) {
125 for (int i = 0; i < sz[it]; i ++) s[i] = buf[l[it] + sz[it] - i - 1];
126 for (int i = 0; i < sz[it]; i ++) t[i] = buf[l[it] + i];
127 ex_kmp(sz[it]);
128 trie::ins(sz[it]);
129 }
130 for (int it = 0; it < n; it ++) {
131 for (int i = 0; i < sz[it]; i ++) s[i] = buf[l[it] + i];
132 for (int i = 0; i < sz[it]; i ++) t[i] = buf[l[it] + sz[it] - i - 1];
133 ex_kmp(sz[it]);
134 trie::go(sz[it]);
135 }
136 printf("%lld\n", ans);
137 return 0;
138 }

 

你可能感兴趣的:(find)