Stream编程之N-gram实现

N-gram是常用的概率语言模型,可以通过已有语料推断语句结构的合理性,在自然语言处理中有着广泛的应用,N-gram的概念就不多说了,网上有大把的教程,想了解的可以自己搜。
Stream是java8的新特性,java8已经发布3年有余了,不知道大家在实际中应用的有多少,工作原因这两年java代码写的比较少,就拿N-gram算法来练练手,个人感觉stream还是很适合做文字处理这种事情的,流式编程写起来还是很方便的。
下面来看看具体实现:

package nlp.ngram;

import java.io.BufferedReader;
import java.io.File;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class Ngram {

    private static Map trainingData;

    static {
        String path = ".zhuxian.txt";
        trainingData = initTrainingData(path);//初始化训练数据,加载到内存
    }

    /**
     * 计算输入句子的合理性概率
     * @param sentence 要评估的语句
     * @param n N-gram的N
     * @return 合理性概率
     */
    private static float getProbability(String sentence, int n) {
        final String sen = "@" + sentence + "#";
        return Stream
            .iterate(0, i -> ++i)
            .limit(sen.length() - n + 1)
            .map(start -> sen.substring(start, start + n))
            .map(s -> {
                float nu = (float) (null == trainingData.get(s) ? 1 : trainingData.get(s).get() + 1);
                float de = (float) (null == trainingData.get(s.substring(0, s.length() - 1)) ? 1 : trainingData.get
                    (s.substring(0, s.length() - 1)).get() + 1);
                System.out.println(s + "/" + s.substring(0, s.length() - 1) + " " + nu / de);
                return nu / de;
            })
            .reduce(1f, (f1, f2) -> f1 * f2);
    }

    /**
     * 把训练数据处理过后加载到内存,统计每个分词的出现频次
     * @param path 训练集路径
     * @return 统计数据map
     */
    private static Map initTrainingData(String path) {
        return readFileOrDir(path)
            .map(s -> s.replaceAll("[”“\\w\\s《》.::*‘’、\"<>\\[\\]^`~]", ""))//去掉文字里的无意义字符,这里只处理中文
            .flatMap(s -> Stream.of(s.split("[,,。;;!!??]")))//分割句子
            .filter(s -> !"".equals(s))//去掉空行
//          .peek(System.out::println)
            .map(s -> "@" + s + "#")//加上句首和句尾标记
            .flatMap(s -> Stream
                .iterate(1, i -> ++i)//支持的N-gram的N为1、2、3、4
                .limit(s.length() > 4 ? 4 : s.length())//N-gram的N最大为4,太大了内存容易爆,实际应用中4基本就够了
                .parallel()
                .flatMap(n -> Stream
                    .iterate(0, i -> ++i)
                    .limit(s.length() - n + 1)
                    .parallel()
                    .map(start -> s.substring(start, start + n))//分割句子为n个字的集合
                )
            )
            .collect(Collectors.toConcurrentMap(o -> o,
                o -> new AtomicInteger(1), (e1, e2) -> {
                e1.incrementAndGet();
                return e1;
            }));
    }

    /**
     * 文件读取工具,以行为单位输出
     * @param path 文件路径
     * @return Stream lines 流
     */
    private static Stream readFileOrDir(String path) {
        File file = new File(path);
        if (file.isDirectory()) {
            String[] paths = file.list((dir, name) -> !name.startsWith("."));
            assert paths != null;
            return Arrays.stream(paths)
                .flatMap(p -> readFileOrDir(path + File.separator + p));
        } else {
            try {
                return new BufferedReader(new java.io.FileReader(path)).lines();
            } catch (Exception e) {
                System.err.println("read file " + path + " error!" + e.getMessage());
                return Stream.empty();
            }
        }
    }

    public static void main(String[] args) throws Exception {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        br.lines().forEach(s -> System.out.println(getProbability(s, 3)));
    }
}

这里用java流式编程可以大大的减少代码量,一个方法处理语料,一个方法计算概率就完事了,对于cpu密集型的任务还可以用多线程 (parallel) 来加速处理速度,不过并发的坑就得自己填了。
这里的实现相当基础,没有分词,准确性会低很多,可优化的空间还很大,还有数据的平滑 (smoothing) 处理这里就不展开讨论了,这里的实现只是简单的把没出现的词出现的次数设为1,实际使用中要结合实际设计数据平滑算法,有时间再填填优化的部分。

你可能感兴趣的:(Stream编程之N-gram实现)