最新JAVA的NLP工具DJL

零、其他:NLP工具包
LingPipe

是alias公司开发的一款自然语言处理软件包。

主题分类(Top Classification)
命名实体识别(Named Entity Recognition)
词性标注(Part-of Speech Tagging)
句题检测(Sentence Detection)
查询拼写检查(Query Spell Checking)
兴趣短语检测(Interseting Phrase Detection)
聚类(Clustering)
字符语言建模(Character Language Modeling)
医学文献下载/解析/索引(MEDLINE Download, Parsing and Indexing)
数据库文本挖掘(Database Text Mining)
中文分词(Chinese Word Segmentation)
情感分析(Sentiment Analysis)
语言辨别(Language Identification)

HanLP

HanLP是一系列模型与算法组成的NLP工具包,目标是普及自然语言处理在生产环境中的应用。

FudanNLP

FNLP主要是为中文自然语言处理而开发的工具包,也包含为实现这些任务的机器学习算法和数据集。 本工具包及其包含数据集使用LGPL3.0许可证。

    信息检索: 文本分类 新闻聚类
    中文处理: 中文分词 词性标注 实体名识别 关键词抽取 依存句法分析 时间短语识别
    结构化学习: 在线学习 层次分类 聚类

Apache OpenNLP

OpenNLP支持最常见的NLP任务:

例如标记化,句子分段,词性标记,命名实体提取,分块,解析,语言检测和共指解析
1
Stanford CoreNLP

斯坦福大学开发的自然语言处理工具套件,包括词性标注、命名实体识别、共指消解等 NLP 任务:


一、简介
开源库以Java构建和部署深度学习、编写一次即可在任何地方运行。使用DJL开发模型并在您选择的引擎上运行。直观的API使用本机Java概念并抽象化了深度学习所涉及的复杂性。引入您自己的模型,或使用我们库中的最新模型在几分钟内进行部署。

二、开源地址:
https://github.com/awslabs/djl

三、例子或者用法

1、Single-shot object detection example
2、Train your first model
3、Image classification example
4、Transfer learning example
5、Train SSD model example
6、Multi-threaded inference example
7、Bert question and answer example
8、Instance segmentation example
9、Pose estimation example
10、Action recognition example
11、Multi-label dataset training example

四、官网地址:

https://djl.ai/

五、代码如下:
1、依赖:


         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    4.0.0
   
        org.springframework.boot
        spring-boot-starter-parent
        2.2.5.RELEASE
       
   

    com.xxxx
    bigdata
    0.0.1-SNAPSHOT
    bigdata
    Demo project for Spring Boot

   
        1.8
        Hoxton.SR1
   

   
       
            org.springframework.boot
            spring-boot-starter-web
       

       
            org.springframework.cloud
            spring-cloud-starter-aws
       

       
            ai.djl
            examples
            0.3.0
       

       
            ai.djl
            api
            0.3.0
       

       
            org.projectlombok
            lombok
            true
       

       
            ai.djl
            basicdataset
            0.3.0
       

       
            ai.djl
            model-zoo
            0.3.0
       

       
            ai.djl.mxnet
            mxnet-model-zoo
            0.3.0
       

       
            org.springframework.boot
            spring-boot-starter-test
            test
           
               
                    org.junit.vintage
                    junit-vintage-engine
               

           

       

   

   
       
           
                org.springframework.cloud
                spring-cloud-dependencies
                ${spring-cloud.version}
                pom
                import
           

       

   

   
       
           
                org.springframework.boot
                spring-boot-maven-plugin
           

       

   



2、问答案例:

package com.xxxx.bigdata.nlputils;

import ai.djl.Application;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.mxnet.zoo.nlp.qa.QAInput;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An example of inference using BertQA.
 *
 *

See:
 *
 *


     *  
  • the jupyter
     *       demo
    with more information about BERT.
     *  
  • the  *       href="https://github.com/awslabs/djl/blob/master/examples/docs/BERT_question_and_answer.md">docs
     *       for information about running this example.
     *

 */
public final class BertQaInference {

    private static final Logger logger = LoggerFactory.getLogger(BertQaInference.class);

    private BertQaInference() {}

    public static void main(String[] args) throws IOException, TranslateException, ModelException {
        String answer = BertQaInference.predict();
        logger.info("Answer: {}", answer);
    }

    public static String predict() throws IOException, TranslateException, ModelException {
        String question = "When did BBC Japan start broadcasting?";
        String paragraph =
                "BBC Japan was a general entertainment Channel.\n"
                        + "Which operated between December 2004 and April 2006.\n"
                        + "It ceased operations after its Japanese distributor folded.";

        QAInput input = new QAInput(question, paragraph, 384);
        logger.info("Paragraph: {}", input.getParagraph());
        logger.info("Question: {}", input.getQuestion());

        Criteria criteria =
                Criteria.builder()
                        .optApplication(Application.NLP.QUESTION_ANSWER)
                        .setTypes(QAInput.class, String.class)
                        .optFilter("backbone", "bert")
                        .optFilter("dataset", "book_corpus_wiki_en_uncased")
                        .optProgress(new ProgressBar())
                        .build();

        try (ZooModel model = ModelZoo.loadModel(criteria)) {
            try (Predictor predictor = model.newPredictor()) {
                return predictor.predict(input);
            }
        }
    }
}


3、训练模型

package com.xxxx.bigdata.nlputils;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicdataset.CaptchaDataset;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.ExampleTrainingResult;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.Dataset.Usage;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.SimpleCompositeLoss;
import ai.djl.training.loss.SoftmaxCrossEntropyLoss;
import ai.djl.training.util.ProgressBar;
import java.io.IOException;
import java.nio.file.Paths;

/**
 * An example of training a CAPTCHA solving model.
 *
 *

See this  * href="https://github.com/awslabs/djl/blob/master/examples/docs/train_captcha.md">doc for
 * information about this example.
 */
public final class TrainCaptcha {

    private TrainCaptcha() {}

    public static void main(String[] args) throws Exception{
        TrainCaptcha.runExample(args);
    }

    public static ExampleTrainingResult runExample(String[] args)
            throws Exception {
        Arguments arguments = Arguments.parseArgs(args);

        try (Model model = Model.newInstance()) {
            model.setBlock(getBlock());

            // get training and validation dataset
            RandomAccessDataset trainingSet = getDataset(Usage.TRAIN, arguments);
            RandomAccessDataset validateSet = getDataset(Usage.VALIDATION, arguments);

            // setup training configuration
            DefaultTrainingConfig config = setupTrainingConfig(arguments);

            ExampleTrainingResult result;
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());

                Shape inputShape =
                        new Shape(1, 1, CaptchaDataset.IMAGE_HEIGHT, CaptchaDataset.IMAGE_WIDTH);

                // initialize trainer with proper input shape
                trainer.initialize(inputShape);

                TrainingUtils.fit(
                        trainer,
                        arguments.getEpoch(),
                        trainingSet,
                        validateSet,
                        arguments.getOutputDir(),
                        "captcha");

                result = new ExampleTrainingResult(trainer);
            }
            model.save(Paths.get(arguments.getOutputDir()), "captcha");
            return result;
        }
    }

    private static DefaultTrainingConfig setupTrainingConfig(Arguments arguments) {
        SimpleCompositeLoss loss = new SimpleCompositeLoss();
        for (int i = 0; i < CaptchaDataset.CAPTCHA_LENGTH; i++) {
            loss.addLoss(new SoftmaxCrossEntropyLoss("loss_digit_" + i), i);
        }

        DefaultTrainingConfig config =
                new DefaultTrainingConfig(loss)
                        .optDevices(Device.getDevices(arguments.getMaxGpus()))
                        .addTrainingListeners(
                                TrainingListener.Defaults.logging(arguments.getModelDir(),arguments.getBatchSize(),arguments.getEpoch(),arguments.getMaxGpus(),arguments.getOutputDir()));

        for (int i = 0; i < CaptchaDataset.CAPTCHA_LENGTH; i++) {
            config.addEvaluator(new Accuracy("acc_digit_" + i, i));
        }

        return config;
    }

    private static RandomAccessDataset getDataset(Dataset.Usage usage, Arguments arguments)
            throws IOException {
        CaptchaDataset dataset =
                CaptchaDataset.builder()
                        .optUsage(usage)
                        .setSampling(arguments.getBatchSize(), true)
                        .optMaxIteration(arguments.getMaxIterations())
                        .build();
        dataset.prepare(new ProgressBar());
        return dataset;
    }

    private static Block getBlock() {
        Block resnet =
                ResNetV1.builder()
                        .setNumLayers(50)
                        .setImageShape(
                                new Shape(
                                        1, CaptchaDataset.IMAGE_HEIGHT, CaptchaDataset.IMAGE_WIDTH))
                        .setOutSize(CaptchaDataset.CAPTCHA_OPTIONS * CaptchaDataset.CAPTCHA_LENGTH)
                        .build();

        return new SequentialBlock()
                .add(resnet)
                .add(
                        resnetOutputList -> {
                            NDArray resnetOutput = resnetOutputList.singletonOrThrow();
                            NDList splitOutput =
                                    resnetOutput
                                            .reshape(
                                                    -1,
                                                    CaptchaDataset.CAPTCHA_LENGTH,
                                                    CaptchaDataset.CAPTCHA_OPTIONS)
                                            .split(CaptchaDataset.CAPTCHA_LENGTH, 1);

                            NDList output = new NDList(CaptchaDataset.CAPTCHA_LENGTH);
                            for (NDArray outputDigit : splitOutput) {
                                output.add(outputDigit.squeeze(1));
                            }
                            return output;
                        });
    }
}

你可能感兴趣的:(算法,java,算法)