本文记述笔者的算法学习心得,由于笔者不熟悉算法,如有错误请不吝指正。
随机句子问题
这几天得见一个算法题:生成随机句子。推测这道题可能来源于某搜索引擎公司的工作实践?
题意为:输入一个句子(即一个单词序列)和一个预期长度,需从输入句子随机抽取单词组成一个新句子(只是一个单词序列,不考虑语法),新句子的单词数量等于预期长度。随机抽取的具体规则为:随机取一个单词作为首个单词加入新句子,每当新句子加入一个单词,对于这个单词在输入句子中的所有出现位置,随机选择一个出现位置的下一个单词(即下邻单词)加入新句子,这一加入又会引起新单词的加入,由此类推,直到新句子达到预期长度,则输出新句子。也就是要实现一个函数String generateSentence(String sentence, int length)。
例如,输入sentence="this is a sentence it is not a good one and it is also bad" length=5,若"sentence"被选为首个单词,则下一个单词只有一个选择"it",然后"it"的下一个单词可以在两个位置选择,但这两个位置都刚好是"is",而"is"的下一个单词就可以选择"not"或"also",若选择"not",则下一个单词只有一个选择"a",此时凑足了5个单词,得到新句子"sentence it is not a"。
以上是此题的基础版,它还有一个难度加强版:对现有规则做一个修改,再给定一个输入m,首次随机取输入句子的连续m个单词加入新句子,之后每次仍加入一个单词,每当新句子加入一个单词,不再是找到这个单词在输入句子中的出现位置,而是对于新句子的“最后m个单词所形成的词组”在输入句子中的所有出现位置,随机选择一个出现位置的下邻单词加入新句子。基础版可看做m=1的特殊情况。
基础版的解法
解法基本上遵循如此结构:
- 将输入句子转换为单词序列(String -> String[],即tokenize),为新句子准备一个内存缓冲区。
- 随机选择首个单词,放入内存缓冲区。
- 循环搜索上一个被选择的单词在输入句子中的所有位置,随机选择一个位置,将其下邻单词放入内存缓冲区,直到新句子达到预期长度,则停止并输出新句子。
注意,把输入句子当作一个环,若搜索到输入句子的句尾单词,则其下邻单词是句首单词,即从句尾跳回了句首。为了在计算下邻位置时做环跳转的处理,我起初是这么写的(在越界时减一次句子总长):
int nextPos(int pos, String[] words) {
int next = pos + 1;
if (next >= words.length) {
next -= words.length;
}
return next;
}
后来听说取模就可以了,即:
int nextPos(int pos, String[] words) {
return (pos + 1) % words.length;
}
对于基础版,减法是OK的,但对于难度加强版呢?如果m>=输入句子长度(虽然这种输入不是很合理),则下邻位置可能跨环两次,需要减两次才正确。因此,取模确实更简洁优雅,更sound。
现在介绍三种算法实现。第一种是暴力搜索法,每次遍历输入句子,找到给定单词的所有位置。 第二种是优化搜索法,对暴力搜索法做了优化,每次不是找到给定单词的所有位置,而是从一个随机位置起步,找到给定单词的某一个位置。第三种是哈希索引法,为输入句子预构建一个哈希表,能快速查找任一个单词对应的下邻单词。
要提前声明,优化搜索法的随机选择可能是不公平的。例如,连续2个相同单词如"go go",从左到右的搜索几乎总会选中左边那个。这个问题能解决,每次随机决定搜索方向即可(从左到右或从右到左)。但是,连续3个相同单词如"go go go "呢?中间的单词被选中的概率微乎其微。我想过”素数跳跃搜索法“等方法,但都没有解决,因此优化搜索法无法成为主流算法。
暴力搜索法的实现代码
class Solution1 {
public static void main(String[] args) {
System.out.println(generateSentence(
"this is a sentence it is not a good one and it is also bad", 5));
}
public static String generateSentence(String sentence, int length) {
String[] words = sentence.split(" ");
if (words.length == 0) {
return sentence;
}
List newWords = new ArrayList<>();
int randIdx = new Random().nextInt(words.length);
String prev = words[randIdx];
newWords.add(prev);
for (int i = 1; i < length; i++) {
List nexts = searchNexts(words, prev);
int chosen = new Random().nextInt(nexts.size());
String chosenWord = words[nexts.get(chosen)];
newWords.add(chosenWord);
prev = chosenWord;
}
return String.join(" ", newWords);
}
// 在words中找到prev的所有出现位置,收集其下邻位置(nexts)
private static List searchNexts(String[] words, String prev) {
List nexts = new ArrayList<>();
for (int i = 0; i < words.length; i++) {
if (words[i].equals(prev)) {
nexts.add(nextPos(i, words));
}
}
return nexts;
}
private static int nextPos(int pos, String[] words) {
return (pos + 1) % words.length;
}
}
优化搜索法的实现代码
class Solution1 {
public static void main(String[] args) {
System.out.println(generateSentence(
"this is a sentence it is not a good one and it is also bad", 5));
}
public static String generateSentence(String sentence, int length) {
String[] words = sentence.split(" ");
if (words.length == 0) {
return sentence;
}
List newWords = new ArrayList<>();
int randIdx = new Random().nextInt(words.length);
String prev = words[randIdx];
newWords.add(prev);
for (int i = 1; i < length; i++) {
String chosenWord = randomNextWord(words, prev);
newWords.add(chosenWord);
prev = chosenWord;
}
return String.join(" ", newWords);
}
private static String randomNextWord(String[] words, String prev) {
int randomBeginIndex = new Random().nextInt(words.length);
for (int _i = 0; _i < words.length; _i++) {
int idx = (randomBeginIndex + _i) % words.length;
if (words[idx].equals(prev)) {
return words[nextPos(idx, words)];
}
}
return null;
}
private static int nextPos(int pos, String[] words) {
return (pos + 1) % words.length;
}
}
哈希索引法的实现代码
class Solution1 {
public static void main(String[] args) {
System.out.println(generateSentence(
"this is a sentence it is not a good one and it is also bad", 5));
}
public static String generateSentence(String sentence, int length) {
String[] words = sentence.split(" ");
if (words.length == 0) {
return sentence;
}
Table table = new Table(words);
List newWords = new ArrayList<>();
int randIdx = new Random().nextInt(words.length);
String prev = words[randIdx];
newWords.add(prev);
for (int i = 1; i < length; i++) {
String chosenWord = table.randomNextWord(prev);
newWords.add(chosenWord);
prev = chosenWord;
}
return String.join(" ", newWords);
}
private static int nextPos(int pos, String[] words) {
return (pos + 1) % words.length;
}
static class Table {
private Map> map = new HashMap<>();
Table(String[] words) {
for (int i = 0; i < words.length; i++) {
List nexts = map.computeIfAbsent(words[i], key -> new ArrayList<>());
nexts.add(words[nextPos(i, words)]);
}
}
String randomNextWord(String word) {
List nexts = map.get(word);
int chosen = new Random().nextInt(nexts.size());
return nexts.get(chosen);
}
}
}
哈希索引法引入一个Table结构来封装建表和查询的逻辑,不但理论上更高效,代码也更易理解。
算法复杂度怎么分析呢?设n=words.length,l=length,则:
- 暴力搜索法主要花时间在n次迭代中每次都要遍历搜索n个单词,时间复杂度为O(n^2 + l),空间复杂度为O(1)。
- 优化搜索法主要花时间在n次迭代中每次都要遍历搜索平均n/(k+1)个单词,k为被搜索单词的出现次数(只出现一次时k=1,则平均遍历n/2个单词,即半个句子),时间复杂度为O(n^2/k + l),空间复杂度为O(1)。k一般比较小,因此只是常数级优化,实际效果待测试。
- 哈希索引法主要花时间在构建一个大小为n的哈希表,时间复杂度为O(n + l),空间复杂度为O(n)。事实上,如果对于相同输入反复调用,由于哈希表可以重用,其时间成本摊销后可忽略不计,即O(1)。
难度加强版的解法
解法基本上遵循如此结构:
- 将输入句子转换为单词序列(String -> String[],即tokenize),为新句子准备一个内存缓冲区。
- 随机选择连续m个单词,放入内存缓冲区。
- 循环搜索新句子“最后m个单词所形成的词组”在输入句子中的所有位置,随机选择一个位置,将其下邻单词放入内存缓冲区,直到新句子达到预期长度,则停止并输出新句子。
若用m作为输入参数的变量名,岂不毫无含义?按m的语义,此变量可名为lookBack,即“回向查看的项数”(在此顺带一提,编译器的语法分析有一概念叫lookAhead,即“前向查看的项数”)。
现在,我们发现可以重构代码,使其更安全优雅:处理环跳转时,对nextPos取模是怕数组越界,既然如此,那就对数组访问都加一道安全防线,数组偏移量都取模不就好了么!
设计这个可复用的数组访问助手函数以代替nextPos函数:
String safeGetWord(String[] words, int index) {
return words[index % words.length];
}
现在再次实现三种算法。
暴力搜索法的实现代码
class Solution2 {
public static void main(String[] args) {
System.out.println(generateSentence(
"this is a sentence it is not a good one and it is also bad", 5, 2));
}
public static String generateSentence(String sentence, int length, int lookBack) {
String[] words = sentence.split(" ");
if (words.length == 0) {
return sentence;
}
if (lookBack > length) {
throw new IllegalArgumentException("lookBack exceeds length");
}
List newWords = new ArrayList<>();
generateLeading(newWords, words, lookBack);
for (int _i = lookBack; _i < length; _i++) {
List prevs = newWords.subList(newWords.size() - lookBack, newWords.size());
List nexts = searchNexts(words, prevs);
int chosen = new Random().nextInt(nexts.size());
String chosenWord = safeGetWord(words, nexts.get(chosen));
newWords.add(chosenWord);
}
return String.join(" ", newWords);
}
// 生成最初的几个单词
private static void generateLeading(List newWords, String[] words, int lookBack) {
int randIdx = new Random().nextInt(words.length);
for (int i = 0; i < lookBack; i++) {
newWords.add(safeGetWord(words, randIdx + i));
}
}
private static List searchNexts(String[] words, List prevs) {
List nexts = new ArrayList<>();
for (int i = 0; i < words.length; i++) {
// 试匹配一个词组
int matchedCount = 0;
for (int j = 0; j < prevs.size(); j++) {
if (!safeGetWord(words, i + j).equals(prevs.get(j))) {
matchedCount = -1;
break;
}
matchedCount++;
}
if (matchedCount == prevs.size()) {
nexts.add(i + prevs.size());
}
}
return nexts;
}
private static String safeGetWord(String[] words, int index) {
return words[index % words.length];
}
}
优化搜索法的实现代码
class Solution2 {
public static void main(String[] args) {
System.out.println(generateSentence(
"this is a sentence it is not a good one and it is also bad", 5, 2));
}
public static String generateSentence(String sentence, int length, int lookBack) {
String[] words = sentence.split(" ");
if (words.length == 0) {
return sentence;
}
if (lookBack > length) {
throw new IllegalArgumentException("lookBack exceeds length");
}
List newWords = new ArrayList<>();
generateLeading(newWords, words, lookBack);
for (int ig = lookBack; ig < length; ig++) {
List prevs = newWords.subList(newWords.size() - lookBack, newWords.size());
String chosenWord = randomNextWord(words, prevs);
newWords.add(chosenWord);
}
return String.join(" ", newWords);
}
// 生成最初的几个单词
private static void generateLeading(List newWords, String[] words, int lookBack) {
int randIdx = new Random().nextInt(words.length);
for (int i = 0; i < lookBack; i++) {
newWords.add(safeGetWord(words, randIdx + i));
}
}
private static String randomNextWord(String[] words, List prevs) {
int randomBeginIndex = new Random().nextInt(words.length);
for (int _i = 0; _i < words.length; _i++) {
int idx = randomBeginIndex + _i;
// 试匹配一个词组
int matchedCount = 0;
for (int j = 0; j < prevs.size(); j++) {
if (!safeGetWord(words, idx + j).equals(prevs.get(j))) {
matchedCount = -1;
break;
}
matchedCount++;
}
if (matchedCount == prevs.size()) {
return safeGetWord(words, idx + prevs.size());
}
}
return null;
}
private static String safeGetWord(String[] words, int index) {
return words[index % words.length];
}
}
哈希索引法的实现代码
class Solution2 {
public static void main(String[] args) {
System.out.println(generateSentence(
"this is a sentence it is not a good one and it is also bad", 5, 2));
}
public static String generateSentence(String sentence, int length, int lookBack) {
String[] words = sentence.split(" ");
if (words.length == 0) {
return sentence;
}
if (lookBack > length) {
throw new IllegalArgumentException("lookBack exceeds length");
}
Table table = new Table(words, lookBack);
List newWords = new ArrayList<>();
generateLeading(newWords, words, lookBack);
for (int ig = lookBack; ig < length; ig++) {
Phrase phrase = new Phrase(newWords.subList(newWords.size() - lookBack, newWords.size()));
String chosenWord = table.randomNextWord(phrase);
newWords.add(chosenWord);
}
return String.join(" ", newWords);
}
private static void generateLeading(List newWords, String[] words, int lookBack) {
int randIdx = new Random().nextInt(words.length);
for (int i = 0; i < lookBack; i++) {
newWords.add(safeGetWord(words, randIdx + i));
}
}
private static String safeGetWord(String[] words, int index) {
return words[index % words.length];
}
static class Phrase {
private List elements;
Phrase(List elements) {
Objects.requireNonNull(elements);
// TODO 应当拷贝一份以确保不可变性
this.elements = elements;
}
Phrase(int lookBack, String[] words, int beginIndex) {
elements = new ArrayList<>(lookBack);
for (int j = 0; j < lookBack; j++) {
elements.add(safeGetWord(words, beginIndex + j));
}
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (!(o instanceof Phrase)) return false;
Phrase phrase = (Phrase) o;
return elements.equals(phrase.elements);
}
@Override
public int hashCode() {
return elements.hashCode();
}
}
static class Table {
private Map> map = new HashMap<>();
Table(String[] words, int lookBack) {
for (int i = 0; i < words.length; i++) {
Phrase phrase = new Phrase(lookBack, words, i);
List nexts = map.computeIfAbsent(phrase, key -> new ArrayList<>());
nexts.add(safeGetWord(words, i + lookBack));
}
}
String randomNextWord(Phrase phrase) {
List nexts = map.get(phrase);
int chosen = new Random().nextInt(nexts.size());
return nexts.get(chosen);
}
}
}
这一回,哈希索引法又引入了一个Phrase结构来作为HashMap key。为什么不直接用List
- 可读性更好。
- key应当是不可变的(immutable),若写入HashMap用的key值不同于查询用的key值,就会查询不到数据。Phrase是一个不可变对象,将List
封装了起来。 - Phrase的hashCode方法是可以修改的,未来可以优化它以提高性能。现在的hashCode是用List
实现的,算法全程的hashCode计算成本为O(mn),m即lookBack。String会缓存hashCode,因此主要成本在于将每个String的hashCode相加,其实性能还好。
算法复杂度怎么分析呢?设n=words.length,l=length,m=lookBack,则:
- 暴力搜索法,时间复杂度为O(mn^2 + l),空间复杂度为O(1)。
- 优化搜索法,时间复杂度为O(mn^2/k + l),空间复杂度为O(1)。
- 哈希索引法主要花时间在构建一个大小为n的索引表,时间复杂度为O(mn + l),空间复杂度为O(mn)。
向生产就绪进发
以上的算法达到生产就绪(production ready)了吗?让我们来考察。
正确性
优化搜索法由于某些情况不公平,在满足功能需求上有所折扣。哈希索引法既高效又易理解,容易正确实现,预计适合用于生产环境。
算法的实现正确吗?需要做功能测试。算法很适合单元测试,但是包含随机数,怎么稳定重现呢,更特别地说,怎么确保测到边界条件呢?
就不放代码了,只说最重要的问题:怎么稳定重现,乃至确保测到边界条件。
答案很简单:应当mock随机数。将用于生成随机数的代码抽成一个函数,在单元测试中将它mock为返回确定的数值。这是因为随机性质不需要放在一起测试,需要被测试的是其他性质。
随机性
注意,随机数的生成要引入熵,才有足够的随机性。以Java的Random伪随机数为例,若全程共用一个Random对象,就只是伪随机,若多次new Random(),由于中间的间隔时间可能随机波动,就引入了时间熵,得到真随机数。也可以全程共用一个SecureRandom对象,这个能提供真随机数,但是速度稍慢一些。
性能
哈希索引法在理论上快n倍,实际有多快呢?
实际性能如何,应当做性能测试,俗话说跑个分(benchmark)。
我们做一个小的性能测试,再做一个大的性能测试,以暴力搜索法为基准来评估性能。
小的性能测试(测试代码略):
将例句重复8遍得到一个8倍长的句子,其余参数不变(length=5,lookBack=2)。将算法预热5万遍,再执行20万遍,多次测试取平均值。(20万遍)花费时间如下:
基本版:
- 暴力搜索法:1358ms
- 优化搜索法:850ms
- 哈希索引法:(不重用哈希表)1212ms,(重用哈希表)699ms
难度加强版:
- 暴力搜索法:1186ms
- 优化搜索法:480ms
- 哈希索引法:(不重用哈希表)1032ms,(重用哈希表)382ms
结果让人惊讶,哈希索引法(不重用哈希表)没有显著快于基准(暴力搜索法)!优化搜索法则表现很好,几乎快1倍!
可能是数据量太小,没有发挥哈希表的优势,那么试一试更大数据量吧!
大的性能测试:
加大数据量,将例句随机重排词序,并重复100遍得到一个100倍长的句子,length=100,lookBack=2。将算法预热1千遍,再执行5千遍,多次测试取平均值。
测试代码要复杂一些,因此提供代码如下:
public static void main(String[] args) {
String sentence = "this is a sentence it is not a good one and it is also bad";
sentence = times(sentence, 100);
// 预热
for (int i = 0; i < 1000; i++) {
generateSentence(
sentence, 100, 2);
}
// 测试
long start = System.currentTimeMillis();
for (int i = 0; i < 5000; i++) {
generateSentence(
sentence, 100, 2);
}
long cost = System.currentTimeMillis() - start;
System.out.println(cost);
}
// 生成长句子
static String times(String input, int n) {
List buffer = new ArrayList<>();
List words = Arrays.asList(input.split(" "));
for (int i = 0; i < n; i++) {
Collections.shuffle(words);
buffer.addAll(words);
}
return String.join(" ", buffer);
}
(5千遍)花费时间如下:
基本版:
- 暴力搜索法:4600ms
- 优化搜索法:450ms
- 哈希索引法:(不重用哈希表)500ms ,(重用哈希表)405ms
难度加强版:
- 暴力搜索法:10000ms
- 优化搜索法:1000ms
- 哈希索引法:(不重用哈希表)1000ms (重用哈希表)420ms
总算发挥了哈希索引法的优势!
在这个数据量,优化搜索法仍然不亚于哈希索引法,而且空间复杂度更优!
如果能从理论上解释这一现象,那就更好了。怎么解释呢?因为是把例句(无论是否随机重排词序)重复100遍生成的输入句子,使得优化搜索法的k≥100,当然会性能很好啊。那么试一试大量生成随机单词,以降低k值:优化搜索法的性能降低了一半,哈希索引法不受影响。
顺便测试了对hashCode的优化,若Phrase只计算第一个单词的hashCode,当lookBack=2时性能无提升,当lookBack=20时性能提升8%。
结论
哈希索引法的高效性、高稳定性、易理解性,使其成为生产环境的首选。注意几个问题:空间复杂度、哈希表只能对相同输入重用、随机性、key不可变、hashCode可优化。
优化搜索法的有趣性,使其值得进一步研究。