编辑距离(Levenshtein)

前言

最近师兄参加招聘笔试的时候,遇到了一道问题,其实就是变种的编辑距离问题。

问题描述

英文单词拼写的时候可能会出现拼写错误的情况(typo)。下面的题目,我们尝试实现拼写纠错推荐的功能。
字串编辑距离是衡量字串间相似度的常见手段。

  1. 字串编辑距离:是指利用字符操作,把字符串A转换成字符串B所需要的最少操作数。
  2. 字串操作包括:删除字符(removal)、插入字符(insertion)、修改字符(substituation)。
  3. 使用以下规则对推荐纠错选项进行相速度排序。得分越高,认为相似度越低

只涉及到26个英文字符、不区分大小写。
删除字符(removal) 3分
插入字符(insertion) 3分
修改字符(substituation):
(q w e r t a s d f g z x c v) (y u i o p h j k l b n m)
以上两个分组内的字符修改1分,两个分组间字符修改2分。

输入

每行一个输入。空格分隔参数。第一个参数是目标单词(可能存在typo),后面若干个空格分割参数(个数不定)是字典单词,作为候选纠错项(也可能和第一个参数重复)。

输出

按照上面的纠错规则字符串相似度最小编辑距离进行排序,给出3个(如有)对应的纠错候选。如得分相同,则按照输入顺序进行排序。

样例

样例输入

slep slap sleep step shoe shop snap slep

样例输出

slep slap step

编辑距离(Levenshtein)算法

算法原理

本算法采用动态规划的思想,计算对于字符串s{1,2,3, ...,n} 到字符串t{1,2,3,...,m}的最短距离(最小代价)。
假设:

  1. 插入一个字符的代价为:a
  2. 删除一个字符的代价为:b
  3. 修改一个字符的代价为:c

对于某个状态:
若字符串t与字符串temp的最小代价为costt,temp = k
且字符串temp与字符串s只需要一次操作(插入、删除、修改字符)
则字符串s与字符串t的最小代价costs,t = k+min(a,b,c)
这其实就是状态转移方程了。

状态转移方程

现在定义 costi,j :长度为i的字符串到长度为j的字符串最小代价。
所以 costn,m = min(①,②,③)
① = costn-1,m + min(a,b)
② = costn,m-1 + min(a,b)
③ = costn-1,m-1 + c
对于① ,从n串到m串,n-1作为temp串,花费代价为 costn-1,m,比较插入和删除操作中较小的代价,实现从n-1串到n串;
对于② ,从n串到m串,m-1作为temp串,花费代价为 costn,m-1 + min(a,b),比较插入和删除操作中较小的代价,实现从m-1串到m串;
对于③,若 n-1串到m-1串代价为k, 则n串到m串代价为 k+c (修改一个字符)

边界条件:
cost0,0 = 0
cost0,1 = min(a,b)
cost1,0 = min(a,b)

求解方法---代价表

步骤 内容
1 构造一张大小为(N+1,M+1)的表,用于储存代价。
2 计算所有边界的代价
3 根据状态转移方程计算整张表

如s="abcd" ,t="acdef" ,插入代价为1,删除代价为1,修改代价为1。
1.构造一张5x6的表,如下所示:

a b c d
a
c
d
e
f

2.计算所有边界条件。

a b c d
0 1 2 3 4
a 1
c 2
d 3
e 4
f 5

3.根据状态转移方程计算整张表.。

a b c d
0 1 2 3 4
a 1 0 1 2 3
c 2 1 1 1 2
d 3 2 2 2 1
e 4 3 3 3 2
f 5 4 4 4 3

最终可以看出,s="abcd" ,t="acdef" 的代价是3,删除b,插入ef。计算结果如表中最后一项。

java代码

import java.util.Arrays;
import java.util.Comparator;
import java.util.Scanner;

public class WordDistance {

    private static int min(int one, int two) {
        int min;
        if (one <= two) min = one;
        else min = two;
        return min;
    }

    private static int min(int one, int two, int three) {
        return min(min(one, two), three);
    }

    public static int wd(String str1, String str2) {
        int cost_add = 3; //插入字符的代价
        int cost_del = 3; //删除字符的代价
        int[][] d; // 矩阵
        int n = str1.length();
        int m = str2.length();
        int i; // 遍历str1的
        int j; // 遍历str2的

        if (n == 0) {
            // str1是空串
            return m * min(cost_add, cost_del);
        }
        if (m == 0) {
            // str2是空串
            return n * min(cost_add, cost_del);
        }
        d = new int[n + 1][m + 1];
        // 初始化所有边界
        for (i = 0; i <= n; i++) { // 初始化第一列
            d[i][0] = i * min(cost_add, cost_del);
        }
        for (j = 0; j <= m; j++) { // 初始化第一行
            d[0][j] = j * min(cost_add, cost_del);
        }
        for (i = 1; i <= n; i++) { // 计算整张表
            char ch1 = str1.charAt(i - 1);
            for (j = 1; j <= m; j++) {
                char ch2 = str2.charAt(j - 1);
                // 状态转移方程
                d[i][j] = min(d[i - 1][j] + min(cost_add, cost_del),
                        d[i][j - 1] + min(cost_add, cost_del),
                        d[i - 1][j - 1] + get_cost(ch1, ch2));
            }
        }
        return d[n][m];
    }

    public static int get_cost(char a, char b) {
        // 题目中,修改字符的代价不同
        String s1 = "qwertasdfgzxcv";
        String s2 = "yuiophjklbnm";
        int cost = 0;
        if (a == b) return cost;
        if (s1.contains(a + "")) {
            if (s1.contains(b + "")) {
                // a b 都在s1组
                cost = 1;
            } else {
                // a b 分别在s1 s2组
                cost = 2;
            }
        } else {
            if (s2.contains(b + "")) {
                // a b 分别在s2 s1组
                cost = 2;
            } else {
                // a b 都在s2组
                cost = 1;
            }
        }
        return cost;
    }

    static class node implements Comparable {
        // 绑定单词和距离代价
        int dis;
        String word;

        node(int dis, String word) {
            this.dis = dis;
            this.word = word;
        }

        @Override
        public int compareTo(Integer o) {
            return this.dis - o;
        }
    }

    public static void main(String[] args) {
        // 从system.in接收字符串
//        Scanner sc = new Scanner(System.in);
//        String s_in = sc.nextLine();
        String s_in = "slep slap sleep step shoe shop snap slep";
        String[] s_splits = s_in.split(" ");
        String target = s_splits[0];
        String[] dictionary = Arrays.copyOfRange(s_splits, 1, s_splits.length);
        int[] distance = new int[dictionary.length];
        node[] node_list = new node[dictionary.length];
        for (int i = 0; i < distance.length; i++) {
            distance[i] = wd(target, dictionary[i]);
            node_list[i] = new node(distance[i], dictionary[i]);
        }
        Arrays.sort(node_list, new Comparator() {
            @Override
            public int compare(node o1, node o2) {
                return o1.dis - o2.dis;
            }
        });
        for (int i = 0; i < node_list.length; i++) {
            if (i > 2) break; //限制输出3个
            System.out.print(node_list[i].word + " ");
        }
    }
}

本人java水平很菜,输出格式是师兄随手写的,我贴来用用。

python代码

class WordDistance:
    cost_add = 3
    cost_del = 3

    @staticmethod
    def get_cost(a, b):
        if a == b:
            return 0
        if a in "qwertasdfgzxcv":
            if b in "qwertasdfgzxcv":
                cost = 1
            else:
                cost = 2
        else:
            if b in "qwertasdfgzxcv":
                cost = 2
            else:
                cost = 1
        return cost

    def __init__(self, ):
        pass

    def wd(self, str1, str2):
        n = len(str1)
        m = len(str2)
        # 空串情况
        if n == 0:
            return m * min(self.cost_add, self.cost_del)
        if m == 0:
            return n * min(self.cost_add, self.cost_del)
        d = [[0 for _ in range(m + 1)] for _ in range(n + 1)]
        for i in range(n + 1):
            # 第一列
            d[i][0] = i * min(self.cost_add, self.cost_del)
        for j in range(m + 1):
            d[0][j] = j * min(self.cost_add, self.cost_del)
        for i in range(1, n + 1):
            ch1 = str1[i - 1]
            for j in range(1, m + 1):
                ch2 = str2[j - 1]
                d[i][j] = min(
                    d[i - 1][j] + min(self.cost_add, self.cost_del),
                    d[i][j - 1] + min(self.cost_add, self.cost_del),
                    d[i - 1][j - 1] + self.get_cost(ch1, ch2)
                )
        return d[n][m]


if __name__ == '__main__':
    s_in = "slep slap sleep step shoe shop snap slep"
    target = s_in.split(" ", 1)[0]
    dictionary = s_in.split(" ")[1:]
    wd = WordDistance()

    wd_dict = {word: wd.wd(target, word) for word in dictionary}
    wd_sort_list = sorted(wd_dict.items(), key=lambda x: x[1])

    for i in range(len(wd_sort_list)):
        if i > 2: break
        print(wd_sort_list[i][0], end=" ")

参考博客

java文本相似度计算(Levenshtein Distance算法(中文翻译:编辑距离算法))----代码和详解

你可能感兴趣的:(编辑距离(Levenshtein))