字符串匹配之Rabin-Karp算法

假设我们输入两个字符串s1、s2,其中s1的长度必须大于等于s2长度。我们要求出s2在s1中出现了多少次。

例如:
输入样例
abcaabcaabcbcabc
abc
输出样例
4

2、算法原理

​ 这道题如果采用枚举暴力的方法时间复杂度会达到O(mn)。枚举方法的思路很简单,利用2个指针从头到尾扫描2个字符串:

//枚举
private static void solve01(String s1, String s2) {
    int len_s1 = s1.length();
    int len_s2 = s2.length();
    int ans = 0;//记录答案
    for(int i = 0; i < len_s1; i++) { //扫描s1字符串
        if(s1.charAt(i) == s2.charAt(0)) {//当s2第一个字符与s1字符串的第一个字符匹配上的时候
            int j = 0;
            for(j = 1; j < len_s2; j++) {//利用第二个指针控制两个字符串的下标验证是否满足要求
                if(s1.charAt(i + j) != s2.charAt(j)) { //不满足则退出循环
                    break;
                }
             }
            if(j == len_s2) { //当j == len_s2的时候则说明第二个循环从头扫到尾,即满足题意。
                ans++;
            //为了验证是否正确,进行了下标打印(下标0代表字符串的第一位)
                System.out.println("match:" + i);
             }
        }
     }
}

枚举方法的时间复杂度计算:
​假设s1长度为m,s2长度为n,扫描s1需要O(m),扫描s2需要O(n),上述代码最坏情况下是O(mn)。

对于字符串较短的时候,可以利用枚举暴力实现,可是当字符串很长的时候,则会超时,因为计算机1s大概可以进行10 ^ 8次方规模的运算,当n和m超过10 ^ 4次方的时候,就会超时,而竞赛的数据往往是10 ^ 9甚至更大,这时需要更优的算法解决,所以下面将介绍Rabin-Karp算法,如何将O(nm)时间复杂度降到O(n + m)。

RabinKarp算法的思路:可以对s2字符串进行特定的处理,得出一个整数值,这个值在术语上称为哈希值,只要能求出这个值,只需要对s1字符串的每个字符截取长度为s2的子字符串进行求哈希值,比对s1子字符串和s2 字符串的哈希值是否相等来判断是否相等。

那么如何求出这个字符串的哈希值呢?
每个字符都有一个ASCII码相对应,那么就可以利用进制的计算原理来求出哈希值:
​例如:
​我们可以把它看成是31进制,则abc可以表示为:
​31 ^ 2 * a + 31 ^ 1 * b + 31 ^ 0 * c
这样子就可以把这个字符串的哈希值求出来了。但这样用代码实现比较难施展开,我们可以对这条式进行变换: (((0 * 31) + a ) * 31 + b) * 31 + c = 31 ^ 2 * a + 31 ^ 1 * b + 31 ^ 0 * c
这样子就可以通过循环计算哈希值了,代码如下:

//计算Hash值
private static long hash(String s) {
    int len = s.length();
    long hash = 0;
    long send = 31;//设置种子数为31,代表是31进制计算;
    /**
    * 假设s长度为3
    * 则s的hash值计算公式为:
    * send^2 * s2[0] + send * s2[1] + send * s2[2]
    * 等价于  = (((send * 0) + s2[0]) * send) + s2[1]) * send + s2[2]
    */
    for(int i = 0; i < len; i++) {
        hash = send * hash + s.charAt(i);
    }
    return hash;
}
那么在知道如何计算一个字符串的哈希值之后,我们就可以通过比较2个字符串的哈希值来判断是否相等了。

代码如下:

//通过计算哈希值
private static void solve02(String s1, String s2) {
    //1.先计算s2的hash值
    long hash_s2 = hash(s2);
    //2.通过遍历求出每个字符作为初始位置且长度为s2的长度的字符串的hash值
    int len_s2 = s2.length();
    int len_s1 = s1.length();
    for(int i = 0; i + len_s2 <= len_s1; i++) {
        long hash_s1 = hash(s1.substring(i, i + len_s2));//计算s1中长度为s2的字串哈希值
        if(hash_s1 == hash_s2) {
            System.out.println("match:" + i);//i表示s1串的下标,下标从0开始
        }
    }
    //3.比较两个hash值是否相同
}

可是我们发现在上述的代码中,时间复杂度依然是O(mn),为什么呢?

一开始需要求一次s2的哈希值,时间复杂度是O(n)

在s1字符串里每长度为n字符串都要求一次哈希值所以是O(nm)

所以最后还是O(mn)复杂度。这相对于枚举其实也没有什么优化呀!

其实接下来才是RabinKarp算法的精髓!!!!!!!!!!!

RabinKarp算法的精髓就是利用滚动哈希来算出解,什么是滚动哈希,一开始我也是一脸懵,其实滚动哈希还是挺容易理解的,相对于动态规划,哈哈!

我们不妨想一下,比如,给出一个字符串s1为abcd,s2为abc,根据上述哈希值计算,计算出了下标为0到下标为2的子串的哈希值,按照上述方法,接下来就是计算bcd的哈希值。滚动哈希的精髓就在此,我们已知abc的哈希值,要求的是bcd的哈希值,那么abc和bcd这两个字符串存在着公共字串,那就是bc,那我们不就可以通过前一个字符串的哈希值乘上种子数(进制)减去 种子数 ^ length(abc) + d

计算abc哈希值:31 ^ 2 * a + 31 ^ 1 * b + 31 ^ 0 * c

计算abcd哈希值:31 ^ 3 * a + 31 ^ 2 * b + 31 ^ 1 * c + 31 ^ 0 * d

可以得到 abc→abcd = (abc) * 31 + d

那么我们就能通过逆推得出bcd的哈希值

即 bcd = (abc ) * 31 + d - a * 31 ^ 3

代码:

private static void solve03(String s1, String s2) {
    // 1.先计算s2的hash值
    long hash_s2 = hash(s2);
    int len_s2 = s2.length();
    int len_s1 = s1.length();
    long[] hash_s1 = new long[len_s1 - len_s2 + 1];//s1所有的字串长度为s1的长度 - s2的长度 + 1
    hash_s1[0] = hash(s1.substring(0, len_s2));//先算出s1中第一个长度为s2的字串的哈希值
    int send = 31;
    int ans = 0;
    for (int i = len_s2; i < len_s1; i++) {//从已计算过哈希值的后一个字符开始
        char newChar = s1.charAt(i);//记录需要添加的新字符
        char oldChar = s1.charAt(i - len_s2);//记录需要删除的旧字符
        long v = (long) (hash_s1[i - len_s2] * send - oldChar * Math.pow(send, len_s2) + newChar) % Long.MAX_VALUE;//根据上述逆推公式
        hash_s1[i - len_s2 + 1] = v; //赋值
    }
    for(int i = 0; i < hash_s1.length; i++) {
        if (hash_s1[i] == hash_s2) { // 当两个哈希值相同的时候,即满足题意。
            ans++;
            // 为了验证是否正确,进行了下标打印(下标0代表字符串的第一位)
            System.out.println("match:" + i);
        }
    }
}

根据代码我们可以计算出它的时间复杂度

计算s2的哈希值需要O(n);

计算s1的每个字串的哈希值需要O(m),为什么不再是O(mn), 因为每个字串的哈希值是从上一个字串哈希值通过公式计算出来的值,不再需要遍历字串的长度计算哈希值,所以每一次计算字串的时间复杂度为O(1);

最后扫描一次求出的s1字串哈希值需要O(m);

所以一共的时间复杂度为O(n + 2m),其中2m的2为常数项,可以省略,最后的时间复杂度为O(n + m)!!!!!!

完整代码:

public class Main {

    public static void main(String[] args) {
        String s1 = "abcaabcaabcbcabc";
        String s2 = "abc";
        solve01(s1, s2);
        solve02(s1, s2);
        solve03(s1, s2);
    }

    private static void solve03(String s1, String s2) {
        System.out.println("--------------->滚动哈希:");
        // 1.先计算s2的hash值
        long hash_s2 = hash(s2);
        int len_s2 = s2.length();
        int len_s1 = s1.length();
        long[] hash_s1 = new long[len_s1 - len_s2 + 1];//s1所有的字串长度为s1的长度 - s2的长度 + 1
        hash_s1[0] = hash(s1.substring(0, len_s2));//先算出s1中第一个长度为s2的字串的哈希值
        int send = 31;
        int ans = 0;
        for (int i = len_s2; i < len_s1; i++) {//从已计算过哈希值的后一个字符开始
            char newChar = s1.charAt(i);
            char oldChar = s1.charAt(i - len_s2);
            long v = (long) (hash_s1[i - len_s2] * send - oldChar * Math.pow(send, len_s2) + newChar) % Long.MAX_VALUE;
            hash_s1[i - len_s2 + 1] = v; 
        }
        for(int i = 0; i < hash_s1.length; i++) {
            if (hash_s1[i] == hash_s2) { // 当j == len_s2的时候则说明第二个循环从头扫到尾,即满足题意。
                ans++;
                // 为了验证是否正确,进行了下标打印(下标0代表字符串的第一位)
                System.out.println("match:" + i);
            }
        }
    }

    private static void solve01(String s1, String s2) {
        System.out.println("--------------->枚举:");
        int len_s1 = s1.length();
        int len_s2 = s2.length();
        int ans = 0;// 记录答案
        for (int i = 0; i < len_s1; i++) { // 扫描s1字符串
            if (s1.charAt(i) == s2.charAt(0)) { // 当s2第一个字符与s1字符串的第一个字符匹配上的时候
                int j = 0;
                for (j = 1; j < len_s2; j++) { // 利用第二个指针控制两个字符串的下标验证是否满足要求
                    if (s1.charAt(i + j) != s2.charAt(j)) { // 不满足则退出循环
                        break;
                    }
                }
                if (j == len_s2) { // 当j == len_s2的时候则说明第二个循环从头扫到尾,即满足题意。
                    ans++;
                    // 为了验证是否正确,进行了下标打印(下标0代表字符串的第一位)
                    System.out.println("match:" + i);
                }
            }
        }
    }

    private static void solve02(String s1, String s2) {
        System.out.println("--------------->未优化哈希计算:");
        // 1.先计算s2的hash值
        long hash_s2 = hash(s2);
        // 2.通过遍历求出每个字符作为初始位置且长度为s2的长度的字符串的hash值
        int len_s2 = s2.length();
        int len_s1 = s1.length();
        for (int i = 0; i + len_s2 <= len_s1; i++) {
            long hash_s1 = hash(s1.substring(i, i + len_s2));
            if (hash_s1 == hash_s2) {
                System.out.println("match:" + i);// i表示s1串的下标,下标从0开始
            }
        }
        // 3.比较两个hash值是否相同
    }

    // 计算Hash值
    private static long hash(String s) {
        int len = s.length();
        long hash = 0;
        long send = 31;// 设置种子数为31;
        /**
         * 假设s长度为3 则s2的hash值计算公式为: send^2 * s2[0] + send * s2[1] + send * s2[2] 等价于 =
         * (((send * 0) + s2[0]) * send) + s2[1]) * send + s2[2]
         */
        for (int i = 0; i < len; i++) {
            hash = send * hash + s.charAt(i);
        }
        return hash;
    }
}

最后我相信你们会产生疑问,万一不同的字符串算出来的哈希值一样呢?

是的,会出现这种情况,但是根据大量数据得出结论:使用100000个不同字符串产生的冲突数,大概在0~3波动,使用100百万不同的字符串,冲突数大概在110+范围波动。所以,如果想要确保万无一失,可以在判断哈希值相等的时候再补一刀,就是判断两个串是否相同,但这样会降低效率,不过在程序比赛中,一般不会出现波动,可以省略!!!

最后,写这些文章就是想通过这样的方法加深自己对算法的印象,因为我是个健忘的人!感谢收看。

你可能感兴趣的:(字符串匹配之Rabin-Karp算法)