生产就绪的算法:以随机句子为例

本文记述笔者的算法学习心得,由于笔者不熟悉算法,如有错误请不吝指正。

随机句子问题

这几天得见一个算法题:生成随机句子。推测这道题可能来源于某搜索引擎公司的工作实践?

题意为:输入一个句子(即一个单词序列)和一个预期长度,需从输入句子随机抽取单词组成一个新句子(只是一个单词序列,不考虑语法),新句子的单词数量等于预期长度。随机抽取的具体规则为:随机取一个单词作为首个单词加入新句子,每当新句子加入一个单词,对于这个单词在输入句子中的所有出现位置,随机选择一个出现位置的下一个单词(即下邻单词)加入新句子,这一加入又会引起新单词的加入,由此类推,直到新句子达到预期长度,则输出新句子。也就是要实现一个函数String generateSentence(String sentence, int length)。

例如,输入sentence="this is a sentence it is not a good one and it is also bad" length=5,若"sentence"被选为首个单词,则下一个单词只有一个选择"it",然后"it"的下一个单词可以在两个位置选择,但这两个位置都刚好是"is",而"is"的下一个单词就可以选择"not"或"also",若选择"not",则下一个单词只有一个选择"a",此时凑足了5个单词,得到新句子"sentence it is not a"。

以上是此题的基础版,它还有一个难度加强版:对现有规则做一个修改,再给定一个输入m,首次随机取输入句子的连续m个单词加入新句子,之后每次仍加入一个单词,每当新句子加入一个单词,不再是找到这个单词在输入句子中的出现位置,而是对于新句子的“最后m个单词所形成的词组”在输入句子中的所有出现位置,随机选择一个出现位置的下邻单词加入新句子。基础版可看做m=1的特殊情况。

基础版的解法

解法基本上遵循如此结构:

  1. 将输入句子转换为单词序列(String -> String[],即tokenize),为新句子准备一个内存缓冲区。
  2. 随机选择首个单词,放入内存缓冲区。
  3. 循环搜索上一个被选择的单词在输入句子中的所有位置,随机选择一个位置,将其下邻单词放入内存缓冲区,直到新句子达到预期长度,则停止并输出新句子。

注意,把输入句子当作一个环,若搜索到输入句子的句尾单词,则其下邻单词是句首单词,即从句尾跳回了句首。为了在计算下邻位置时做环跳转的处理,我起初是这么写的(在越界时减一次句子总长):

int nextPos(int pos, String[] words) {
  int next = pos + 1;
  if (next >= words.length) {
    next -= words.length;
  }
  return next;
}

后来听说取模就可以了,即:

int nextPos(int pos, String[] words) {
  return (pos + 1) % words.length;
}

对于基础版,减法是OK的,但对于难度加强版呢?如果m>=输入句子长度(虽然这种输入不是很合理),则下邻位置可能跨环两次,需要减两次才正确。因此,取模确实更简洁优雅,更sound。

现在介绍三种算法实现。第一种是暴力搜索法,每次遍历输入句子,找到给定单词的所有位置。 第二种是优化搜索法,对暴力搜索法做了优化,每次不是找到给定单词的所有位置,而是从一个随机位置起步,找到给定单词的某一个位置。第三种是哈希索引法,为输入句子预构建一个哈希表,能快速查找任一个单词对应的下邻单词。

要提前声明,优化搜索法的随机选择可能是不公平的。例如,连续2个相同单词如"go go",从左到右的搜索几乎总会选中左边那个。这个问题能解决,每次随机决定搜索方向即可(从左到右或从右到左)。但是,连续3个相同单词如"go go go "呢?中间的单词被选中的概率微乎其微。我想过”素数跳跃搜索法“等方法,但都没有解决,因此优化搜索法无法成为主流算法。

暴力搜索法的实现代码

class Solution1 {
  public static void main(String[] args) {
    System.out.println(generateSentence(
      "this is a sentence it is not a good one and it is also bad", 5));
  }

  public static String generateSentence(String sentence, int length) {
    String[] words = sentence.split(" ");
    if (words.length == 0) {
      return sentence;
    }

    List newWords = new ArrayList<>();

    int randIdx = new Random().nextInt(words.length);
    String prev = words[randIdx];
    newWords.add(prev);

    for (int i = 1; i < length; i++) {
      List nexts = searchNexts(words, prev);
      int chosen = new Random().nextInt(nexts.size());
      String chosenWord = words[nexts.get(chosen)];
      newWords.add(chosenWord);
      prev = chosenWord;
    }

    return String.join(" ", newWords);
  }

  // 在words中找到prev的所有出现位置,收集其下邻位置(nexts)
  private static List searchNexts(String[] words, String prev) {
    List nexts = new ArrayList<>();
    for (int i = 0; i < words.length; i++) {
      if (words[i].equals(prev)) {
        nexts.add(nextPos(i, words));
      }
    }

    return nexts;
  }

  private static int nextPos(int pos, String[] words) {
    return (pos + 1) % words.length;
  }
}

优化搜索法的实现代码

class Solution1 {
  public static void main(String[] args) {
    System.out.println(generateSentence(
      "this is a sentence it is not a good one and it is also bad", 5));
  }

  public static String generateSentence(String sentence, int length) {
    String[] words = sentence.split(" ");
    if (words.length == 0) {
      return sentence;
    }

    List newWords = new ArrayList<>();

    int randIdx = new Random().nextInt(words.length);
    String prev = words[randIdx];
    newWords.add(prev);

    for (int i = 1; i < length; i++) {
      String chosenWord = randomNextWord(words, prev);
      newWords.add(chosenWord);
      prev = chosenWord;
    }

    return String.join(" ", newWords);
  }

  private static String randomNextWord(String[] words, String prev) {
    int randomBeginIndex = new Random().nextInt(words.length);

    for (int _i = 0; _i < words.length; _i++) {
      int idx = (randomBeginIndex + _i) % words.length;
      if (words[idx].equals(prev)) {
        return words[nextPos(idx, words)];
      }
    }

    return null;
  }

  private static int nextPos(int pos, String[] words) {
    return (pos + 1) % words.length;
  }
}

哈希索引法的实现代码

class Solution1 {
  public static void main(String[] args) {
    System.out.println(generateSentence(
      "this is a sentence it is not a good one and it is also bad", 5));
  }

  public static String generateSentence(String sentence, int length) {
    String[] words = sentence.split(" ");
    if (words.length == 0) {
      return sentence;
    }

    Table table = new Table(words);

    List newWords = new ArrayList<>();

    int randIdx = new Random().nextInt(words.length);
    String prev = words[randIdx];
    newWords.add(prev);

    for (int i = 1; i < length; i++) {
      String chosenWord = table.randomNextWord(prev);
      newWords.add(chosenWord);
      prev = chosenWord;
    }

    return String.join(" ", newWords);
  }

  private static int nextPos(int pos, String[] words) {
    return (pos + 1) % words.length;
  }

  static class Table {
    private Map> map = new HashMap<>();

    Table(String[] words) {
      for (int i = 0; i < words.length; i++) {
        List nexts = map.computeIfAbsent(words[i], key -> new ArrayList<>());
        nexts.add(words[nextPos(i, words)]);
      }
    }

    String randomNextWord(String word) {
      List nexts = map.get(word);
      int chosen = new Random().nextInt(nexts.size());
      return nexts.get(chosen);
    }
  }
}

哈希索引法引入一个Table结构来封装建表和查询的逻辑,不但理论上更高效,代码也更易理解。
算法复杂度怎么分析呢?设n=words.length,l=length,则:

  1. 暴力搜索法主要花时间在n次迭代中每次都要遍历搜索n个单词,时间复杂度为O(n^2 + l),空间复杂度为O(1)。
  2. 优化搜索法主要花时间在n次迭代中每次都要遍历搜索平均n/(k+1)个单词,k为被搜索单词的出现次数(只出现一次时k=1,则平均遍历n/2个单词,即半个句子),时间复杂度为O(n^2/k + l),空间复杂度为O(1)。k一般比较小,因此只是常数级优化,实际效果待测试。
  3. 哈希索引法主要花时间在构建一个大小为n的哈希表,时间复杂度为O(n + l),空间复杂度为O(n)。事实上,如果对于相同输入反复调用,由于哈希表可以重用,其时间成本摊销后可忽略不计,即O(1)。

难度加强版的解法

解法基本上遵循如此结构:

  1. 将输入句子转换为单词序列(String -> String[],即tokenize),为新句子准备一个内存缓冲区。
  2. 随机选择连续m个单词,放入内存缓冲区。
  3. 循环搜索新句子“最后m个单词所形成的词组”在输入句子中的所有位置,随机选择一个位置,将其下邻单词放入内存缓冲区,直到新句子达到预期长度,则停止并输出新句子。

若用m作为输入参数的变量名,岂不毫无含义?按m的语义,此变量可名为lookBack,即“回向查看的项数”(在此顺带一提,编译器的语法分析有一概念叫lookAhead,即“前向查看的项数”)。

现在,我们发现可以重构代码,使其更安全优雅:处理环跳转时,对nextPos取模是怕数组越界,既然如此,那就对数组访问都加一道安全防线,数组偏移量都取模不就好了么!
设计这个可复用的数组访问助手函数以代替nextPos函数:

String safeGetWord(String[] words, int index) {
  return words[index % words.length];
}

现在再次实现三种算法。

暴力搜索法的实现代码

class Solution2 {
  public static void main(String[] args) {
    System.out.println(generateSentence(
      "this is a sentence it is not a good one and it is also bad", 5, 2));
  }

  public static String generateSentence(String sentence, int length, int lookBack) {
    String[] words = sentence.split(" ");
    if (words.length == 0) {
      return sentence;
    }
    if (lookBack > length) {
      throw new IllegalArgumentException("lookBack exceeds length");
    }

    List newWords = new ArrayList<>();

    generateLeading(newWords, words, lookBack);

    for (int _i = lookBack; _i < length; _i++) {
      List prevs = newWords.subList(newWords.size() - lookBack, newWords.size());
      List nexts = searchNexts(words, prevs);

      int chosen = new Random().nextInt(nexts.size());
      String chosenWord = safeGetWord(words, nexts.get(chosen));
      newWords.add(chosenWord);
    }

    return String.join(" ", newWords);
  }

  // 生成最初的几个单词
  private static void generateLeading(List newWords, String[] words, int lookBack) {
    int randIdx = new Random().nextInt(words.length);
    for (int i = 0; i < lookBack; i++) {
      newWords.add(safeGetWord(words, randIdx + i));
    }
  }

  private static List searchNexts(String[] words, List prevs) {
    List nexts = new ArrayList<>();

    for (int i = 0; i < words.length; i++) {
      // 试匹配一个词组
      int matchedCount = 0;

      for (int j = 0; j < prevs.size(); j++) {
        if (!safeGetWord(words, i + j).equals(prevs.get(j))) {
          matchedCount = -1;
          break;
        }

        matchedCount++;
      }

      if (matchedCount == prevs.size()) {
        nexts.add(i + prevs.size());
      }
    }

    return nexts;
  }

  private static String safeGetWord(String[] words, int index) {
    return words[index % words.length];
  }
}

优化搜索法的实现代码

class Solution2 {
  public static void main(String[] args) {
    System.out.println(generateSentence(
      "this is a sentence it is not a good one and it is also bad", 5, 2));
  }

  public static String generateSentence(String sentence, int length, int lookBack) {
    String[] words = sentence.split(" ");
    if (words.length == 0) {
      return sentence;
    }
    if (lookBack > length) {
      throw new IllegalArgumentException("lookBack exceeds length");
    }

    List newWords = new ArrayList<>();

    generateLeading(newWords, words, lookBack);

    for (int ig = lookBack; ig < length; ig++) {
      List prevs = newWords.subList(newWords.size() - lookBack, newWords.size());
      String chosenWord = randomNextWord(words, prevs);
      newWords.add(chosenWord);
    }

    return String.join(" ", newWords);
  }

  // 生成最初的几个单词
  private static void generateLeading(List newWords, String[] words, int lookBack) {
    int randIdx = new Random().nextInt(words.length);
    for (int i = 0; i < lookBack; i++) {
      newWords.add(safeGetWord(words, randIdx + i));
    }
  }

  private static String randomNextWord(String[] words, List prevs) {
    int randomBeginIndex = new Random().nextInt(words.length);

    for (int _i = 0; _i < words.length; _i++) {
      int idx = randomBeginIndex + _i;
      // 试匹配一个词组
      int matchedCount = 0;

      for (int j = 0; j < prevs.size(); j++) {
        if (!safeGetWord(words, idx + j).equals(prevs.get(j))) {
          matchedCount = -1;
          break;
        }

        matchedCount++;
      }

      if (matchedCount == prevs.size()) {
        return safeGetWord(words, idx + prevs.size());
      }
    }

    return null;
  }

  private static String safeGetWord(String[] words, int index) {
    return words[index % words.length];
  }
}

哈希索引法的实现代码

class Solution2 {
  public static void main(String[] args) {
    System.out.println(generateSentence(
      "this is a sentence it is not a good one and it is also bad", 5, 2));
  }

  public static String generateSentence(String sentence, int length, int lookBack) {
    String[] words = sentence.split(" ");
    if (words.length == 0) {
      return sentence;
    }
    if (lookBack > length) {
      throw new IllegalArgumentException("lookBack exceeds length");
    }

    Table table = new Table(words, lookBack);

    List newWords = new ArrayList<>();

    generateLeading(newWords, words, lookBack);

    for (int ig = lookBack; ig < length; ig++) {
      Phrase phrase = new Phrase(newWords.subList(newWords.size() - lookBack, newWords.size()));
      String chosenWord = table.randomNextWord(phrase);
      newWords.add(chosenWord);
    }

    return String.join(" ", newWords);
  }

  private static void generateLeading(List newWords, String[] words, int lookBack) {
    int randIdx = new Random().nextInt(words.length);
    for (int i = 0; i < lookBack; i++) {
      newWords.add(safeGetWord(words, randIdx + i));
    }
  }

  private static String safeGetWord(String[] words, int index) {
    return words[index % words.length];
  }

  static class Phrase {
    private List elements;

    Phrase(List elements) {
      Objects.requireNonNull(elements);
      // TODO 应当拷贝一份以确保不可变性
      this.elements = elements;
    }

    Phrase(int lookBack, String[] words, int beginIndex) {
      elements = new ArrayList<>(lookBack);
      for (int j = 0; j < lookBack; j++) {
        elements.add(safeGetWord(words, beginIndex + j));
      }
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) return true;
      if (!(o instanceof Phrase)) return false;
      Phrase phrase = (Phrase) o;
      return elements.equals(phrase.elements);
    }

    @Override
    public int hashCode() {
      return elements.hashCode();
    }
  }

  static class Table {
    private Map> map = new HashMap<>();

    Table(String[] words, int lookBack) {
      for (int i = 0; i < words.length; i++) {
        Phrase phrase = new Phrase(lookBack, words, i);

        List nexts = map.computeIfAbsent(phrase, key -> new ArrayList<>());
        nexts.add(safeGetWord(words, i + lookBack));
      }
    }

    String randomNextWord(Phrase phrase) {
      List nexts = map.get(phrase);
      int chosen = new Random().nextInt(nexts.size());
      return nexts.get(chosen);
    }
  }
}

这一回,哈希索引法又引入了一个Phrase结构来作为HashMap key。为什么不直接用List作为key呢?有两个原因:

  1. 可读性更好。
  2. key应当是不可变的(immutable),若写入HashMap用的key值不同于查询用的key值,就会查询不到数据。Phrase是一个不可变对象,将List封装了起来。
  3. Phrase的hashCode方法是可以修改的,未来可以优化它以提高性能。现在的hashCode是用List实现的,算法全程的hashCode计算成本为O(mn),m即lookBack。String会缓存hashCode,因此主要成本在于将每个String的hashCode相加,其实性能还好。

算法复杂度怎么分析呢?设n=words.length,l=length,m=lookBack,则:

  1. 暴力搜索法,时间复杂度为O(mn^2 + l),空间复杂度为O(1)。
  2. 优化搜索法,时间复杂度为O(mn^2/k + l),空间复杂度为O(1)。
  3. 哈希索引法主要花时间在构建一个大小为n的索引表,时间复杂度为O(mn + l),空间复杂度为O(mn)。

向生产就绪进发

以上的算法达到生产就绪(production ready)了吗?让我们来考察。

正确性

优化搜索法由于某些情况不公平,在满足功能需求上有所折扣。哈希索引法既高效又易理解,容易正确实现,预计适合用于生产环境。

算法的实现正确吗?需要做功能测试。算法很适合单元测试,但是包含随机数,怎么稳定重现呢,更特别地说,怎么确保测到边界条件呢?
就不放代码了,只说最重要的问题:怎么稳定重现,乃至确保测到边界条件。
答案很简单:应当mock随机数。将用于生成随机数的代码抽成一个函数,在单元测试中将它mock为返回确定的数值。这是因为随机性质不需要放在一起测试,需要被测试的是其他性质。

随机性

注意,随机数的生成要引入熵,才有足够的随机性。以Java的Random伪随机数为例,若全程共用一个Random对象,就只是伪随机,若多次new Random(),由于中间的间隔时间可能随机波动,就引入了时间熵,得到真随机数。也可以全程共用一个SecureRandom对象,这个能提供真随机数,但是速度稍慢一些。

性能

哈希索引法在理论上快n倍,实际有多快呢?
实际性能如何,应当做性能测试,俗话说跑个分(benchmark)。
我们做一个小的性能测试,再做一个大的性能测试,以暴力搜索法为基准来评估性能。

小的性能测试(测试代码略):
将例句重复8遍得到一个8倍长的句子,其余参数不变(length=5,lookBack=2)。将算法预热5万遍,再执行20万遍,多次测试取平均值。(20万遍)花费时间如下:

  1. 基本版:

    1. 暴力搜索法:1358ms
    2. 优化搜索法:850ms
    3. 哈希索引法:(不重用哈希表)1212ms,(重用哈希表)699ms
  2. 难度加强版:

    1. 暴力搜索法:1186ms
    2. 优化搜索法:480ms
    3. 哈希索引法:(不重用哈希表)1032ms,(重用哈希表)382ms

结果让人惊讶,哈希索引法(不重用哈希表)没有显著快于基准(暴力搜索法)!优化搜索法则表现很好,几乎快1倍!
可能是数据量太小,没有发挥哈希表的优势,那么试一试更大数据量吧!

大的性能测试:
加大数据量,将例句随机重排词序,并重复100遍得到一个100倍长的句子,length=100,lookBack=2。将算法预热1千遍,再执行5千遍,多次测试取平均值。
测试代码要复杂一些,因此提供代码如下:

public static void main(String[] args) {
  String sentence = "this is a sentence it is not a good one and it is also bad";
  sentence = times(sentence, 100);

  // 预热
  for (int i = 0; i < 1000; i++) {
    generateSentence(
      sentence, 100, 2);
  }

  // 测试
  long start = System.currentTimeMillis();
  for (int i = 0; i < 5000; i++) {
    generateSentence(
      sentence, 100, 2);
  }
  long cost = System.currentTimeMillis() - start;
  System.out.println(cost);
}

// 生成长句子
static String times(String input, int n) {
  List buffer = new ArrayList<>();
  List words = Arrays.asList(input.split(" "));
  for (int i = 0; i < n; i++) {
    Collections.shuffle(words);
    buffer.addAll(words);
  }
  return String.join(" ", buffer);
}

(5千遍)花费时间如下:

  1. 基本版:

    1. 暴力搜索法:4600ms
    2. 优化搜索法:450ms
    3. 哈希索引法:(不重用哈希表)500ms ,(重用哈希表)405ms
  2. 难度加强版:

    1. 暴力搜索法:10000ms
    2. 优化搜索法:1000ms
    3. 哈希索引法:(不重用哈希表)1000ms (重用哈希表)420ms

总算发挥了哈希索引法的优势!
在这个数据量,优化搜索法仍然不亚于哈希索引法,而且空间复杂度更优!
如果能从理论上解释这一现象,那就更好了。怎么解释呢?因为是把例句(无论是否随机重排词序)重复100遍生成的输入句子,使得优化搜索法的k≥100,当然会性能很好啊。那么试一试大量生成随机单词,以降低k值:优化搜索法的性能降低了一半,哈希索引法不受影响。

顺便测试了对hashCode的优化,若Phrase只计算第一个单词的hashCode,当lookBack=2时性能无提升,当lookBack=20时性能提升8%。

结论

哈希索引法的高效性、高稳定性、易理解性,使其成为生产环境的首选。注意几个问题:空间复杂度、哈希表只能对相同输入重用、随机性、key不可变、hashCode可优化。
优化搜索法的有趣性,使其值得进一步研究。

你可能感兴趣的:(算法)