数据库djl_认识深度Java库(DJL)

数据库djl

重要要点

  • 开发人员可以使用Java和他们最喜欢的IDE构建,训练和部署机器学习(ML)和深度学习(DL)模型
  • DJL简化了深度学习(DL)框架的使用,目前支持Apache MXNet
  • DJL的开源性质对于工具包及其用户应该是互惠互利的
  • DJL与引擎无关,这意味着开发人员只需编写一次代码即可在任何引擎上运行
  • Java开发人员在尝试使用DJL之前应该了解ML生命周期和通用ML术语。

亚马逊的DJL是一种深度学习工具包,用于在Java 中原生开发机器学习(ML)和深度学习(DL)模型,同时简化了深度学习框架的使用。 DJL是re:Invent 2019的及时开放源代码工具包,提供了一组高级API来训练,测试和运行推理。 Java开发人员可以开发自己的模型,也可以利用数据科学家从Python中使用Java代码开发的经过预先训练的模型。

DJL通过与引擎和深度学习框架无关,坚守Java的座右铭:“编写一次,随处运行(WORA)”。 一旦在任何引擎上运行,开发人员就可以编写代码。 DJL当前为Apache MXNet提供了一种实现,该ML引擎简化了深度神经网络的开发。 DJL API使用JNA(Java本机访问)来调用相应的Apache MXNet操作。 DJL协调基础架构管理,基于硬件配置提供自动CPU / GPU检测,以确保良好的性能。

DJL API通过抽象常用功能来开发模型,从而使Java开发人员可以利用现有知识来简化向ML的过渡。 为了了解DJL的实际效果,让我们使用开发鞋类分类模型作为一个简单示例。

机器学习生命周期

遵循机器学习生命周期来生成鞋类分类模型。 ML生命周期不同于传统的软件开发生命周期,它包含六个具体步骤:

  1. 获取数据
  2. 清理并准备数据
  3. 产生模型
  4. 评估模型
  5. 部署模型
  6. 从模型获得预测(或推论)

生命周期的最终结果是一个机器学习模型,可以查询该模型并返回答案(或预测)。

数据库djl_认识深度Java库(DJL)_第1张图片

模型只是数据中趋势和模式的数学表示。 好的数据是所有机器学习项目的基础。

在步骤1中,从信誉良好的来源获取数据。 在第2步中,将数据清理,转换并以机器可以学习的格式放置。 清理和转换过程通常是机器学习生命周期中最耗时的部分。 DJL通过提供使用翻译器预处理图像的功能,使开发人员可以简化此过程。 翻译人员可以执行诸如根据预期参数调整图像大小或将图像从彩色转换为灰度的任务。

过渡到机器学习的开发人员通常会低估清理和转换数据所需的时间,因此翻译员是快速启动该过程的好方法。 在训练过程的第3步中,机器学习算法对数据进行多次遍历(或历时),然后对它们进行研究,以尝试学习不同类型的鞋类。 发现的与鞋类有关的趋势和样式存储在模型中。 当评估模型以确定模型在识别鞋类方面的能力时,第4步是训练的一部分。 如果发现错误,则将其纠正。 在步骤5中,将模型部署到生产环境。 模型投入生产后,第6步允许模型被其他系统使用。

通常,可以将模型动态加载到您的代码中,或通过基于REST的HTTPS端点进行访问。

数据

鞋类分类模型是一种多类分类计算机视觉(CV)模型,使用监督学习进行训练,该模型将鞋类分为四个类别标签之一:靴子,凉鞋,鞋子或拖鞋。 监督学习必须包括已经用您要预测的目标(或答案)标记的数据; 这是机器学习的方式。

鞋类分类模型的数据源是德克萨斯大学奥斯汀分校提供的UTZappos50k数据集,可免费用于学术,非商业用途。 鞋子数据集包含从Zappos.com收集的50,025张带标签的目录图像。

数据库djl_认识深度Java库(DJL)_第2张图片

鞋类数据保存在本地,并使用DJL的ImageFolder数据集加载,该数据集从本地文件夹中检索图像。

// identify the location of the training data
String trainingDatasetRoot = "src/test/resources/imagefolder/train";

// identify the location of the validation data
String validateDatasetRoot = "src/test/resources/imagefolder/validate";

//create training data ImageFolder dataset
ImageFolder trainingDataset = initDataset(trainingDatasetRoot);

//create validation data ImageFolder dataset
ImageFolder validateDataset = initDataset(validateDatasetRoot);

在本地构建数据时,我并没有深入到UTZappos50k数据集所确定的最细粒度的级别,例如脚踝,膝盖高,小腿中部,膝盖上方等。靴子的分类标签。 我的本地数据处于最高分类级别,其中仅包括靴子,凉鞋,鞋子和拖鞋。

数据库djl_认识深度Java库(DJL)_第3张图片

用DJL术语来说,数据集只是保存训练数据。 有一些数据集实现可用于下载数据(基于您提供的URL),提取数据并自动将数据分为训练和验证集。

自动分离是一个有用的功能,因为从不使用与训练模型相同的数据来验证模型的性能非常重要。 模型使用训练数据集来查找鞋类数据中的趋势和模式。 验证数据集用于通过对鞋类进行分类的模型准确性的无偏估计来限定模型的性能。

如果使用与训练时相同的数据对模型进行验证,则我们对模型进行鞋分类的信心将会大大降低,因为正在使用已经看到的数据对模型进行测试。 在现实世界中,老师不会使用学习指南中提供的完全相同的问题来测试学生,因为这无法衡量学生对材料的真实了解或理解; 随后,机器学习模型也适用相同的概念。

训练

现在,我们已将鞋类数据分为训练和验证集,现在让我们使用神经网络来训练(或产生)模型。

public final class Training extends AbstractTraining {

     . . .

     @Override
     protected void train(Arguments arguments) throws IOException {

          // identify the location of the training data
          String trainingDatasetRoot = "src/test/resources/imagefolder/train";

          // identify the location of the validation data
          String validateDatasetRoot = "src/test/resources/imagefolder/validate";

          //create training data ImageFolder dataset
          ImageFolder trainingDataset = initDataset(trainingDatasetRoot);

          //create validation data ImageFolder dataset
          ImageFolder validateDataset = initDataset(validateDatasetRoot);

          . . .

          try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) {
             TrainingConfig config = setupTrainingConfig(loss);

             try (Trainer trainer = model.newTrainer(config)) {
                 trainer.setMetrics(metrics);

                 trainer.setTrainingListener(this);

                 Shape inputShape = new Shape(1, 3, NEW_HEIGHT, NEW_WIDTH);

                 // initialize trainer with proper input shape
                 trainer.initialize(inputShape);

                 //find the patterns in data
                 fit(trainer, trainingDataset, validateDataset, "build/logs/training");

                 //set model properties
                 model.setProperty("Epoch", String.valueOf(EPOCHS));
                 model.setProperty("Accuracy", String.format("%.2f", getValidationAccuracy()));

                // save the model after done training for inference later
                //model saved as shoeclassifier-0000.params
                model.save(Paths.get(modelParamsPath), modelParamsName);
             }
          }
     }

 }

第一步是通过调用Models.getModel(NUM_OF_OUTPUT,NEW_HEIGHT,NEW_WIDTH)获得模型实例。 深度学习是机器学习的一种形式,它使用神经网络来训练模型。 神经网络以人脑中的神经元为模型。 神经元是将信息(或数据)传输到其他细胞的细胞。

ResNet-50是经常用于图像分类的神经网络。 50表示原始输入数据和最终预测之间存在50层学习(或神经元)。 getModel()方法创建一个空模型,构造一个ResNet-50神经网络,并将该神经网络设置为模型。

public class Models {
   public static ai.djl.Model getModel(int numOfOutput, int height, int width) {
       //create new instance of an empty model
       ai.djl.Model model = ai.djl.Model.newInstance();

       //Block is a composable unit that forms a neural network; combine them
       //like Lego blocks to form a complex network
       Block resNet50 =
               //construct the network
               new ResNetV1.Builder()
                       .setImageShape(new Shape(3, height, width))
                       .setNumLayers(50)
                       .setOutSize(numOfOutput)
                       .build();

       //set the neural network to the model
       model.setBlock(resNet50);
       return model;
   }
}

下一步是通过调用model.newTrainer(config)方法来设置和配置Trainer。 通过调用setupTrainingConfig(loss)方法来初始化config对象,该方法设置训练配置(或超参数)以确定如何训练网络。

接下来的步骤允许我们通过设置以下内容来向Trainer添加功能:

  • 使用度量 trainer.setMetrics(metrics)
  • 使用trainer.setTrainingListener( this )的训练侦听器
  • 使用trainer.initialize(inputShape)正确的输入形状

Metrics在培训期间收集并报告关键绩效指标(KPI),可用于分析和监视培训绩效和稳定性。 下一步是通过调用fit(trainer,trainingDataset,validateDataset, “ build / logs / training”方法来开始训练过程,该方法将迭代训练数据并存储在模型中找到的模式。 训练结束时,将使用model.save(Paths.get(modelParamsPath),modelParamsName)方法在本地保存性能良好且经过验证的模型工件。

培训过程中报告的指标如下所示。 请注意,随着每个时代(或每个阶段),模型的准确性都会提高; 第9个阶段的最终训练准确性为90%。

推理

现在我们已经生成了模型,可以将其用于对我们不知道其分类(或目标)的新数据执行推断(或预测)。

private Classifications predict() throws IOException, ModelException, TranslateException  {
   //the location to the model saved during training
   String modelParamsPath = "build/logs";

   //the name of the model set during training
   String modelParamsName = "shoeclassifier";

   // the path of image to classify
   String imageFilePath = "src/test/resources/slippers.jpg";

   //Load the image file from the path
   BufferedImage img = BufferedImageUtils.fromFile(Paths.get(imageFilePath));

   //holds the probability score per label
   Classifications predictResult;

   try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) {
       //load the model
       model.load(Paths.get(modelParamsPath), modelParamsName);

       //define a translator for pre and post processing
       Translator translator = new MyTranslator();

       //run the inference using a Predictor
       try (Predictor predictor = model.newPredictor(translator)) {
           predictResult = predictor.predict(img);
       }
   }

   return predictResult;
}

在设置了模型和要分类的图像的必要路径之后,请使用Models.getModel(NUM_OF_OUTPUT,NEW_HEIGHT,NEW_WIDTH)方法获取空的模型实例然后使用model.load(Paths.get(modelParamsPath),modelParamsName对其进行初始化)方法。 这将加载上一步中训练的模型。

接下来,使用model.newPredictor(translator)方法使用指定的Translator初始化Predictor 。 用DJL术语来说, 翻译器提供模型预处理和后处理功能。 例如,对于CV模型,需要将图像重塑为灰度。 译者可以做到这一点。 Predictor允许我们使用predictor.predict(img)方法对加载的模型进行推断,并传入图像进行分类。

此示例显示了单个预测,但DJL还支持批量预测。 推论存储在predictResult中 ,其中包含每个标签的概率估计。

推断(每个图像)及其对应的概率得分如下所示。

图片 机率分数
数据库djl_认识深度Java库(DJL)_第4张图片

[信息]-[
类别:“ 0”,概率: 0.98985
等级:“ 1”,概率:0.00225
等级:“ 2”,概率:0.00224
等级:“ 3”,概率:0.00564
]
0类代表靴子 ,其概率得分为98.98%

数据库djl_认识深度Java库(DJL)_第5张图片

[信息]-[
等级:“ 0”,概率:0.02111
等级:“ 1”,概率: 0.76524
等级:“ 2”,概率:0.01159
等级:“ 3”,概率:0.20204
]
1类代表凉鞋 ,概率得分为76.52%

数据库djl_认识深度Java库(DJL)_第6张图片

[信息]-[
类别:“ 0”,概率:0.05523
类别:“ 1”,概率:0.0417
等级:“ 2”,概率: 0.87900
等级:“ 3”,概率:0.05158
]
2类代表鞋子的概率得分为87.90%

数据库djl_认识深度Java库(DJL)_第7张图片

[信息]-[
类别:“ 0”,概率:0.00003
等级:“ 1”,概率:0.01133
等级:“ 2”,概率:0.00179
等级:“ 3”,概率: 0.98682
]
3类代表拖鞋 ,概率得分为98.68%

DJL提供了本机Java开发经验,其功能与其他Java库一样。 这些API旨在指导开发人员以最佳实践完成深度学习任务。 在开始DJL之前,需要对ML生命周期有一个很好的了解。 如果您不熟悉ML,请阅读概述,或者从InfoQ的文章系列开始,这是针对软件开发人员的机器学习简介 。 了解了生命周期和常见的ML术语后,开发人员可以快速掌握DJL的API。

亚马逊已经开源了DJL,有关该工具包的更多详细信息可以在DJL网站和Java Library API Specification页面上找到。 可以查看鞋类分类模型的代码,以进一步探索示例。

翻译自: https://www.infoq.com/articles/djl-deep-learning-java/?topicPageSponsorship=c1246725-b0a7-43a6-9ef9-68102c8d48e1

数据库djl

你可能感兴趣的:(神经网络,大数据,python,机器学习,人工智能)