今天刚学KMP,做点输出,一方面查找漏洞,加深理解,另一方面帮助大家学习。
给你两个字符串 n n n, m m m,问你在 n n n是否包含 m m m
例如 n = " a b c a b d a b c a b c " , m = " a b c a b c " n="abcabdabcabc",m="abcabc" n="abcabdabcabc",m="abcabc"
朴素解法:将m与n逐个对比
n| a b c a b d a b c a b c
m| a b c a b c
^(指针)
n| a b c a b d a b c a b c
m| a b c a b c
^(指针)
n| a b c a b d a b c a b c
m| a b c a b c
^(指针)
下面是朴素解法后面的其中几步
n| a b c a b d a b c a b c
m| a b c a b c
^(指针)
n| a b c a b d a b c a b c
m| a b c a b c
^(指针)
可以计算出,朴素算法的时间复杂度是 O ( n m ) O(nm) O(nm)
然而,在朴素解法第3步中,我们其实可以直接将 m m m串往后移动3个位置而不是1个位置
依据:在前面比对完的的 a b c a b abcab abcab串中,前缀ab和后缀ab相同,也即该串的最大公共前缀后缀长度为2,因此我们可以直接跳过中间的 b 、 c b、c b、c,也即将 m m m串往后移动2个位置。
n| a b c a b d a b c a b c
m| a b c a b c
* * ^(指针)
这时,我们仍然可以保证, m m m串前面的 a b ab ab与 n n n串后面的 a b ab ab(上面*标记的两个位置)依旧可以对应,于是可以使指针不回退,继续向前搜素。正因为KMP算法中的指针并没有回退,而是一直向前或暂时暂停,因此具有 O ( n + m ) O(n+m) O(n+m)的优秀时间复杂度
了解了KMP算法的原理之后,要做的就是计算 m m m串每个位置前缀字符串的最大前缀后缀公共长度,例如:
a b c a b c
a| 0
a b| 0
a b c| 0
a b c a| 1 (a=a)
a b c a b| 2 (ab=ab)
a b c a b c| 3 (abc=abc)
如果能够预处理出上面的数据,那么我们匹配时,就可以最大程度的往后移,节省时间。当然,这个预处理也是KMP算法的核心(也最难理解)。
这个预处理数组在算法书中被记作 n e x t next next数组, n e x t [ i ] next[i] next[i]表示 m m m串前 i i i个字符的最大公共长度,例如对应例子中的 m m m串, n e x t [ 1 ] = n e x t [ 2 ] = n e x t [ 3 ] = 0 , n e x t [ 4 ] = 1 , n e x t [ 5 ] = 2 , n e x t [ 6 ] = 3 next[1]=next[2]=next[3]=0,next[4]=1,next[5]=2,next[6]=3 next[1]=next[2]=next[3]=0,next[4]=1,next[5]=2,next[6]=3。
假设现在已经得到 n e x t [ j ] next[j] next[j]了,怎么递推 n e x t [ j + 1 ] next[j+1] next[j+1]呢?
分两种情况递推:
1. m [ j + 1 ] = m [ n e x t [ j ] ] m[j+1]=m[next[j]] m[j+1]=m[next[j]]
看上去很复杂,我们来看具体例子
m| a b c a b c
^(指针)
假设我们已经求出 n e x t [ 5 ] = 2 next[5]=2 next[5]=2,(显然,可以直接看出这个结果,为什么前面可以求得 n e x t [ 5 ] = 2 next[5]=2 next[5]=2不重要,重要的是怎么用前者递推后者),当 m [ j + 1 ] = m [ n e x t [ j ] ] m[j+1]=m[next[j]] m[j+1]=m[next[j]]时,也就是 m [ 5 ] = m [ 2 ] m[5]=m[2] m[5]=m[2],这时候最大公共长度也就是在前面的基础上+1, n e x t [ 6 ] = n e x t [ 5 ] + 1 next[6]=next[5]+1 next[6]=next[5]+1。应该不难理解。
很巧妙的一点是, n e x t [ 5 ] = 2 next[5]=2 next[5]=2表示的是前5个字符的最大公共长度是2,这时候的 m [ n e x t [ 5 ] ] m[next[5]] m[next[5]]恰巧又是最大前缀的后面那个字符。因此要比对 j + 1 j+1 j+1位置和上一次最大前缀后一个字符是否相同,就是条件中的 m [ j + 1 ] = m [ n e x t [ j ] ] m[j+1]=m[next[j]] m[j+1]=m[next[j]]
看代码:
void getnxt()
{
ll lenb = b.size(); //求b字符串对应的nxt数组
ll j = 0, k = -1; //k表示上一次的最大公共长度,也就是nxt[j-1]
nxt[0] = -1;
while (j < lenb)
{
if (k == -1 || b[j] == b[k]) //如果下一位相同,那么就+1
nxt[++j] = ++k;
else //如果不同,也就是下面要讲的情况2
k = nxt[k];
}
}
2.当然就是 m [ j + 1 ] ≠ m [ n e x t [ j ] ] m[j+1]\ne m[next[j]] m[j+1]=m[next[j]]
这是核心(难点)中的核心(难点)
为了方便说明,我们以字符串"abaaba#abaabaa"为例
a b a a b a # a b a a b a a
----- ----- ----- -----
1 2 3 4
----------- -----------
5 6
假设我们现在已经求得 k = n e x t [ 13 ] = 6 k=next[13]=6 k=next[13]=6,要递推得到 n e x t [ 14 ] next[14] next[14]
上面的代码已经给出结论,我们需要不断的回溯 k = n e x t [ k ] k=next[k] k=next[k]直到遇到 − 1 -1 −1或者匹配成功。
为什么是 k = n e x t [ k ] k=next[k] k=next[k]呢?
看上面的例子中,虽然 # 和 a a a不相同,但是我们已经知道 a a a前面的 n e x t [ 13 ] next[13] next[13]个字符和整个字符串最前面的 n e x t [ 13 ] next[13] next[13]个字符是相同的(也就是 5 5 5和 6 6 6部分)。在上面的例子中,又有 k = n e x t [ k ] = n e x t [ n e x t [ 13 ] ] = n e x t [ 6 ] = 3 k=next[k]=next[next[13]]=next[6]=3 k=next[k]=next[next[13]]=next[6]=3,这时候就有,前 6 6 6个字符的最大前缀后缀公共长度为 3 3 3,也就是 1 = 2 1=2 1=2,又有 5 = 6 5=6 5=6,所以得到 1 = 4 1=4 1=4(匹配成功)所以要只需从 1 1 1和 4 4 4的后面继续对比即可。当然也可能遇到 − 1 -1 −1,也就是回溯到 n e x t [ 0 ] = − 1 next[0]=-1 next[0]=−1也不能匹配,那么下一个位置的最大公共长度自然就是 0 0 0了
/*
* @Author: hesorchen
* @Date: 2020-07-02 22:19:34
* @LastEditTime: 2020-07-10 21:13:02
* @Description: https://hesorchen.github.io/
*/
#include
using namespace std;
#define ll long long
ll nxt[1000100];
string a, b;
void getnxt()
{
ll lenb = b.size(); //求b字符串对应的nxt数组
ll j = 0, k = -1; //k表示上一次的最大公共长度,也就是nxt[j-1]
nxt[0] = -1;
while (j < lenb)
{
if (k == -1 || b[j] == b[k]) //如果下一位相同,那么就+1
nxt[++j] = ++k;
else //如果不同,回溯找最优(长)的匹配子串
k = nxt[k];
}
}
bool kmp()
{
getnxt();
ll lena = a.size(), lenb = b.size();
ll i = 0, j = 0; //i指向a字符串,j指向b字符串
while (i < lena)
//可以发现 i指针从没有回退,保证了KMP的时间复杂度
{
if (j == -1 || a[i] == b[j])
i++, j++;
else
j = nxt[j]; //从最长公共长度后缀下一位开始继续找
if (j == lenb) //找到
return true;
}
return false;
}
int main()
{
while (cin >> a >> b)
cout << (kmp() ? "yes" : "no") << endl;
return 0;
}