POJ 3376 Finding Palindromes(扩展kmp+trie)

题目链接:http://poj.org/problem?id=3376

题意:给你n个字符串m1、m2、m3...mn 求S = mimj(1=<i,j<=n)是回文串的数量

思路:我们考虑第i个字符串和第j个字符串能构成组合回文串要满足的条件:

1、i的长度小于j,那么i一定是j的反串的前缀,且j的反串剩下的后缀是回文串

2、i的长度等于j,那么i等于j的反串

3、i的长度大于j,那么j的反串一定是i的前缀,且i串剩下的后缀是回文串

我们可以将这n个字符串插入trie,每个节点要维护两个值:value1. 到当前节点的字符串个数;value2. 当前节点后面的回文子串个数

我们用每个字符串的反串去trie上查找,要构成回文串有以下情况:

1、 此反串是其他串的前缀,那么组合回文串的数量就要加上value2

2、此反串的前缀是某些字符串,且反串剩下的后缀是回文串,那么组合回文串的数量要加上value1

3、2的特例:此反串的前缀是某些字符串,且反串剩下的后缀为空,同样要加上value1,这种情况可以和2一起处理

关键:

1、判断字符串的哪些后缀是回文串(用于更新value2),以及对应反串的哪些后缀是回文串(当面临第二种情况时,可直接判断后缀否为回文串)

2、如何更新value1和value2(借助1的结果)

  1 #include <cstdio>

  2 #include <cstring>

  3 #include <algorithm>

  4 using namespace std;

  5 typedef long long LL;

  6 const int MAXN = 2000005;

  7 const int KIND = 26;

  8 

  9 struct TrieNode

 10 {

 11     int num;    // 到当前节点的字符串个数

 12     int cnt;    // 当前节点后面回文子串个数

 13     TrieNode* nxt[26];

 14 };

 15 

 16 TrieNode node[MAXN];     // 避免动态申请空间的时间消耗

 17 TrieNode* root;           // trie树的根节点

 18 int bg[MAXN];             // bg[i]第i+1个字符串开始的位置

 19 int ed[MAXN];             // ed[i]第i+1个字符串结束的位置

 20 bool flag[2][MAXN];       // flag[0][i]为true表示原串后面为回文串   flag[1][i]表示反串

 21 char S[MAXN];             // 存放原串

 22 char T[MAXN];             // 存放反串

 23 int nxt[MAXN];            // 存放next数组

 24 int extend[MAXN];         // 用于判断是否为回文子串

 25 LL ans;                   // 保存结果

 26 int tot;                  // node数组的下标

 27 

 28 void GetNext(char* T, int lhs, int rhs)

 29 {

 30     int j = 0;

 31     while (lhs + j + 1 <= rhs && T[lhs + j] == T[lhs + j + 1]) ++j;

 32     nxt[lhs + 1] = j;

 33     int k = lhs + 1;

 34     for (int i = lhs + 2; i <= rhs; ++i)

 35     {

 36         int p = nxt[k] + k - 1;

 37         int L = nxt[lhs + i - k];

 38         if (L + i < p + 1) nxt[i] = L;

 39         else

 40         {

 41             j = max(0, p - i + 1);

 42             while (i + j <= rhs && T[lhs + j] == T[i + j]) ++j;

 43             nxt[i] = j;

 44             k = i;

 45         }

 46     }

 47 }

 48 

 49 void ExtendKMP(char* S, char* T, int lhs, int rhs, bool sign)

 50 {

 51     GetNext(T, lhs, rhs);

 52     int j = 0;

 53     while (j + lhs <= rhs && S[j + lhs] == T[j + lhs]) ++j;

 54     extend[lhs] = j;

 55     int k = lhs;

 56     for (int i = lhs + 1; i <= rhs; ++i)

 57     {

 58         int p = extend[k] + k - 1;

 59         int L = nxt[lhs + i - k];

 60         if (L + i < p + 1) extend[i] = L;

 61         else

 62         {

 63             j = max(0, p - i + 1);

 64             while (i + j <= rhs && S[i + j] == T[lhs + j]) ++j;

 65             extend[i] = j;

 66             k = i;

 67         }

 68     }

 69     for (int i = lhs; i <= rhs; ++i)

 70     {

 71         if (extend[i] == rhs - i + 1)

 72             flag[sign][i] = true;

 73     }

 74 }

 75 

 76 void Insert(char S[], int lhs, int rhs)

 77 {

 78     TrieNode* temp = root;

 79     for (int i = lhs; i <= rhs; ++i)

 80     {

 81         int ch = S[i] - 'a';

 82         temp->cnt += flag[0][i];    // 更新当前节点后面回文子串的数目

 83         if (temp->nxt[ch] == NULL) temp->nxt[ch] = &node[tot++];

 84         temp = temp->nxt[ch];

 85     }

 86     ++temp->num; // 更新到当前节点的字符串数目

 87 }

 88 

 89 void Search(char S[], int lhs, int rhs)

 90 {

 91     TrieNode* temp = root;

 92     for (int i = lhs; i <= rhs; ++i)

 93     {

 94         int ch = S[i] - 'a';

 95         temp = temp->nxt[ch];

 96         if (temp == NULL) break;

 97         if ((i < rhs && flag[1][i + 1]) || i == rhs)

 98             ans += temp->num;

 99     }

100     if (temp) ans += temp->cnt;

101 }

102 

103 int main()

104 {

105     int n;

106     while (scanf("%d", &n) != EOF)

107     {

108         // 初始化

109         tot = 0;

110         ans = 0;

111         memset(node, 0, sizeof(node));

112         memset(flag, 0, sizeof(flag));

113         root = &node[tot++];

114 

115         int l = 0;

116         int L = 0;

117         for (int i = 0; i < n; ++i)

118         {

119             // 输入一组数据

120             scanf("%d", &l);

121             scanf("%s", S + L);

122 

123             // 生成反串

124             for (int j = 0; j < l; ++j)

125                 T[L + j] = S[L + l - 1 - j];

126 

127             bg[i] = L;

128             ed[i] = L + l - 1;

129 

130 

131             ExtendKMP(S, T , bg[i], ed[i], 0);

132             ExtendKMP(T, S , bg[i], ed[i], 1);

133             Insert(S, bg[i], ed[i]);

134 

135             L += l;

136         }

137 

138         for (int i = 0; i < n; ++i)

139             Search(T, bg[i], ed[i]);

140 

141         printf("%lld\n", ans);

142     }

143     return 0;

144 }

 

你可能感兴趣的:(find)