KMP算法,又称模式匹配算法,能够在线性时间内判定字符串 T 是否为 S 的子串,并求出字符串 T 在 S 中各次出现的位置。
KMP算法比较晦涩难懂。本文对于思想介绍略简,侧重于实现。
问题模型: 给定两个字符串 S 和 T ,试求出 T 在 S 中第一次出现的位置。
上述问题模型是模式串匹配最基础的模型,即单模式串匹配问题,这类问题是KMP算法以及字符串Hash大展身手的题型。
算法思路1:Hash
设|S| = n , |T| = m。如果不考虑冲突,那么我们可以将 S 的所有长度为 m 的子串hash值都求出来,复杂度为O(N)。将这 n-m+1 个子串与T的hash值在O(1)的时间内一一比对,即可通过hash值是否相同来判断是否匹配成功。
但实际上如果n和m很大(1e6),那么散列值冲突是不可避免的,此时需要二次判断或者通过其他方法(构造更好的散列函数)来在保证速度的情况下提升正确性。
算法思路2:KMP
设|S| = n , |T| = m。首先考虑一个朴素算法,那就是将字符串 S 中的每一个长度为m的子串都与 T 进行一次匹配,失配后再匹配下一个,复杂度O(NM)。
手动模拟一下可以发现,上述做法中指向字符串 S 的指针和 T 的指针都有回退 [ 1 ] ^{[1]} [1],但实际上我们并不需要发生回退,KMP算法就是通过防止指针回退来提升朴素算法效率的。
假设我们 S[i] 和 T[j+1] 发生了失配,如果我们知道 “T 中以 j 为末尾的真子串” 和 T[1, j] 的最长公共前缀的长度(假设为len,len一定小于 j ),那么显然 T[1, len] = S[i-len+1, i];于是此时的 j = len,接着匹配即可。我们用nex数组(见下文)来存放 T 对应位置的“len”。
详细的讲,KMP算法分为两步:
[1] 指针回退:在朴素做法中,如果发生失配,则要将指向 S 串的指针回退到当前子串起始位置,并右移至下一个子串起始位置,同理指向 T 的指针也要回到起始位置。
首先要明白什么是Next数组(以下简称nex数组)。
nex[i]表示“T 中以 i 结尾的非前缀子串”与“T 的前缀”能够匹配的最长长度,即:nex[i] = max{j},其中j < i 并且 T[i-j+1, i] = T[i, j]。
跳过:nex数组起到什么辅助作用,为什么要用nex数组?
nex 数组的求法
代码块
void getNex(const char *s){
/*更新模式串s的nex数组*/
int len = strlen(s);
memset(nex,0,sizeof nex);
for(int i = 2,j = 0;i < len;i++){
while(j > 0 && s[i] != s[j+1]) j = nex[j];
if(s[i] == s[j+1]) j++;
nex[i] = j;
}
}
按照前面的定义, f[i] 表示“S 中以 i 结尾的子串”与“ T 的前缀”能够匹配的最长长度。可以发现 f 数组和 nex 数组定义是一致的,因此他们的求解过程也基本一致。
代码块
void getF(const char* S,const char *T){
/*求解 f 数组,S是目标串,T是模式串*/
memset(f,0,sizeof f);
int len1 = strlen(S),len2 = strlen(T);
for(int i = 1,j = 0;i < len1;i++){
while(j > 0 && (j == len2 || S[i] != T[j+1])) j = nex[j];
if(S[i] == T[j+1]) j++;
f[i] = j;
}
}
测试地址
代码模板
/*
KMP算法模板-ValenShi
最后修改:2019/9/26
使用说明:
1.字符串起始位置是1而不是0,修改可能会出错.
2.记得初始化nex与f数组
3.原串长度与模式串长度都在函数中用strlen更新,无需修改全局变量
*/
#include
using namespace std;
const int N = 1e6+10;
char s1[N],s2[N];
int nex[N],f[N];
void getNex(const char *s){
/*更新模式串s的nex数组*/
int len = strlen(s);
memset(nex,0,sizeof nex);
for(int i = 2,j = 0;i < len;i++){
while(j > 0 && s[i] != s[j+1]) j = nex[j];
if(s[i] == s[j+1]) j++;
nex[i] = j;
}
}
void getF(const char* S,const char *T){
/*求解 f 数组,S是目标串,T是模式串*/
memset(f,0,sizeof f);
int len1 = strlen(S),len2 = strlen(T);
for(int i = 1,j = 0;i < len1;i++){
while(j > 0 && (j == len2 || S[i] != T[j+1])) j = nex[j];
if(S[i] == T[j+1]) j++;
f[i] = j;
}
}
void solve(){
/*求解nex数组与f数组,并 按要求 输出答案*/
getNex(s2);
getF(s1,s2);
int len1 = strlen(s1)-1,len2 = strlen(s2)-1;
for(int i = 1;i <= len1;i++){
if(f[i] == len2) printf("%d\n",i-len2+1);
}
for(int i = 1;i <= len2;i++) printf("%d ",nex[i]);
}
int main(){
scanf("%s%s",s1+1,s2+1);
s1[0] = s2[0] = '#';//不然strlen函数无法使用
solve();
return 0;
}