数据库djl
亚马逊的DJL是一种深度学习工具包,用于在Java 中原生开发机器学习(ML)和深度学习(DL)模型,同时简化了深度学习框架的使用。 DJL是re:Invent 2019的及时开放源代码工具包,提供了一组高级API来训练,测试和运行推理。 Java开发人员可以开发自己的模型,也可以利用数据科学家从Python中使用Java代码开发的经过预先训练的模型。
DJL通过与引擎和深度学习框架无关,坚守Java的座右铭:“编写一次,随处运行(WORA)”。 一旦在任何引擎上运行,开发人员就可以编写代码。 DJL当前为Apache MXNet提供了一种实现,该ML引擎简化了深度神经网络的开发。 DJL API使用JNA(Java本机访问)来调用相应的Apache MXNet操作。 DJL协调基础架构管理,基于硬件配置提供自动CPU / GPU检测,以确保良好的性能。
DJL API通过抽象常用功能来开发模型,从而使Java开发人员可以利用现有知识来简化向ML的过渡。 为了了解DJL的实际效果,让我们使用开发鞋类分类模型作为一个简单示例。
遵循机器学习生命周期来生成鞋类分类模型。 ML生命周期不同于传统的软件开发生命周期,它包含六个具体步骤:
生命周期的最终结果是一个机器学习模型,可以查询该模型并返回答案(或预测)。
在步骤1中,从信誉良好的来源获取数据。 在第2步中,将数据清理,转换并以机器可以学习的格式放置。 清理和转换过程通常是机器学习生命周期中最耗时的部分。 DJL通过提供使用翻译器预处理图像的功能,使开发人员可以简化此过程。 翻译人员可以执行诸如根据预期参数调整图像大小或将图像从彩色转换为灰度的任务。
过渡到机器学习的开发人员通常会低估清理和转换数据所需的时间,因此翻译员是快速启动该过程的好方法。 在训练过程的第3步中,机器学习算法对数据进行多次遍历(或历时),然后对它们进行研究,以尝试学习不同类型的鞋类。 发现的与鞋类有关的趋势和样式存储在模型中。 当评估模型以确定模型在识别鞋类方面的能力时,第4步是训练的一部分。 如果发现错误,则将其纠正。 在步骤5中,将模型部署到生产环境。 模型投入生产后,第6步允许模型被其他系统使用。
通常,可以将模型动态加载到您的代码中,或通过基于REST的HTTPS端点进行访问。
鞋类分类模型是一种多类分类计算机视觉(CV)模型,使用监督学习进行训练,该模型将鞋类分为四个类别标签之一:靴子,凉鞋,鞋子或拖鞋。 监督学习必须包括已经用您要预测的目标(或答案)标记的数据; 这是机器学习的方式。
鞋类分类模型的数据源是德克萨斯大学奥斯汀分校提供的UTZappos50k数据集,可免费用于学术,非商业用途。 鞋子数据集包含从Zappos.com收集的50,025张带标签的目录图像。
鞋类数据保存在本地,并使用DJL的ImageFolder数据集加载,该数据集从本地文件夹中检索图像。
// identify the location of the training data
String trainingDatasetRoot = "src/test/resources/imagefolder/train";
// identify the location of the validation data
String validateDatasetRoot = "src/test/resources/imagefolder/validate";
//create training data ImageFolder dataset
ImageFolder trainingDataset = initDataset(trainingDatasetRoot);
//create validation data ImageFolder dataset
ImageFolder validateDataset = initDataset(validateDatasetRoot);
在本地构建数据时,我并没有深入到UTZappos50k数据集所确定的最细粒度的级别,例如脚踝,膝盖高,小腿中部,膝盖上方等。靴子的分类标签。 我的本地数据处于最高分类级别,其中仅包括靴子,凉鞋,鞋子和拖鞋。
用DJL术语来说,数据集只是保存训练数据。 有一些数据集实现可用于下载数据(基于您提供的URL),提取数据并自动将数据分为训练和验证集。
自动分离是一个有用的功能,因为从不使用与训练模型相同的数据来验证模型的性能非常重要。 模型使用训练数据集来查找鞋类数据中的趋势和模式。 验证数据集用于通过对鞋类进行分类的模型准确性的无偏估计来限定模型的性能。
如果使用与训练时相同的数据对模型进行验证,则我们对模型进行鞋分类的信心将会大大降低,因为正在使用已经看到的数据对模型进行测试。 在现实世界中,老师不会使用学习指南中提供的完全相同的问题来测试学生,因为这无法衡量学生对材料的真实了解或理解; 随后,机器学习模型也适用相同的概念。
现在,我们已将鞋类数据分为训练和验证集,现在让我们使用神经网络来训练(或产生)模型。
public final class Training extends AbstractTraining {
. . .
@Override
protected void train(Arguments arguments) throws IOException {
// identify the location of the training data
String trainingDatasetRoot = "src/test/resources/imagefolder/train";
// identify the location of the validation data
String validateDatasetRoot = "src/test/resources/imagefolder/validate";
//create training data ImageFolder dataset
ImageFolder trainingDataset = initDataset(trainingDatasetRoot);
//create validation data ImageFolder dataset
ImageFolder validateDataset = initDataset(validateDatasetRoot);
. . .
try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) {
TrainingConfig config = setupTrainingConfig(loss);
try (Trainer trainer = model.newTrainer(config)) {
trainer.setMetrics(metrics);
trainer.setTrainingListener(this);
Shape inputShape = new Shape(1, 3, NEW_HEIGHT, NEW_WIDTH);
// initialize trainer with proper input shape
trainer.initialize(inputShape);
//find the patterns in data
fit(trainer, trainingDataset, validateDataset, "build/logs/training");
//set model properties
model.setProperty("Epoch", String.valueOf(EPOCHS));
model.setProperty("Accuracy", String.format("%.2f", getValidationAccuracy()));
// save the model after done training for inference later
//model saved as shoeclassifier-0000.params
model.save(Paths.get(modelParamsPath), modelParamsName);
}
}
}
}
第一步是通过调用Models.getModel(NUM_OF_OUTPUT,NEW_HEIGHT,NEW_WIDTH)获得模型实例。 深度学习是机器学习的一种形式,它使用神经网络来训练模型。 神经网络以人脑中的神经元为模型。 神经元是将信息(或数据)传输到其他细胞的细胞。
ResNet-50是经常用于图像分类的神经网络。 50表示原始输入数据和最终预测之间存在50层学习(或神经元)。 getModel()方法创建一个空模型,构造一个ResNet-50神经网络,并将该神经网络设置为模型。
public class Models {
public static ai.djl.Model getModel(int numOfOutput, int height, int width) {
//create new instance of an empty model
ai.djl.Model model = ai.djl.Model.newInstance();
//Block is a composable unit that forms a neural network; combine them
//like Lego blocks to form a complex network
Block resNet50 =
//construct the network
new ResNetV1.Builder()
.setImageShape(new Shape(3, height, width))
.setNumLayers(50)
.setOutSize(numOfOutput)
.build();
//set the neural network to the model
model.setBlock(resNet50);
return model;
}
}
下一步是通过调用model.newTrainer(config)方法来设置和配置Trainer。 通过调用setupTrainingConfig(loss)方法来初始化config对象,该方法设置训练配置(或超参数)以确定如何训练网络。
接下来的步骤允许我们通过设置以下内容来向Trainer添加功能:
trainer.setMetrics(metrics)
trainer.setTrainingListener( this )
的训练侦听器 trainer.initialize(inputShape)
正确的输入形状 Metrics
在培训期间收集并报告关键绩效指标(KPI),可用于分析和监视培训绩效和稳定性。 下一步是通过调用fit(trainer,trainingDataset,validateDataset, “ build / logs / training” )方法来开始训练过程,该方法将迭代训练数据并存储在模型中找到的模式。 训练结束时,将使用model.save(Paths.get(modelParamsPath),modelParamsName)方法在本地保存性能良好且经过验证的模型工件。
培训过程中报告的指标如下所示。 请注意,随着每个时代(或每个阶段),模型的准确性都会提高; 第9个阶段的最终训练准确性为90%。
现在我们已经生成了模型,可以将其用于对我们不知道其分类(或目标)的新数据执行推断(或预测)。
private Classifications predict() throws IOException, ModelException, TranslateException {
//the location to the model saved during training
String modelParamsPath = "build/logs";
//the name of the model set during training
String modelParamsName = "shoeclassifier";
// the path of image to classify
String imageFilePath = "src/test/resources/slippers.jpg";
//Load the image file from the path
BufferedImage img = BufferedImageUtils.fromFile(Paths.get(imageFilePath));
//holds the probability score per label
Classifications predictResult;
try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) {
//load the model
model.load(Paths.get(modelParamsPath), modelParamsName);
//define a translator for pre and post processing
Translator translator = new MyTranslator();
//run the inference using a Predictor
try (Predictor predictor = model.newPredictor(translator)) {
predictResult = predictor.predict(img);
}
}
return predictResult;
}
在设置了模型和要分类的图像的必要路径之后,请使用Models.getModel(NUM_OF_OUTPUT,NEW_HEIGHT,NEW_WIDTH)方法获取空的模型实例,然后使用model.load(Paths.get(modelParamsPath),modelParamsName对其进行初始化)方法。 这将加载上一步中训练的模型。
接下来,使用model.newPredictor(translator)方法使用指定的Translator初始化Predictor 。 用DJL术语来说, 翻译器提供模型预处理和后处理功能。 例如,对于CV模型,需要将图像重塑为灰度。 译者可以做到这一点。 Predictor允许我们使用predictor.predict(img)方法对加载的模型进行推断,并传入图像进行分类。
此示例显示了单个预测,但DJL还支持批量预测。 推论存储在predictResult中 ,其中包含每个标签的概率估计。
推断(每个图像)及其对应的概率得分如下所示。
图片 | 机率分数 |
---|---|
[信息]-[ |
|
[信息]-[ |
|
[信息]-[ |
|
[信息]-[ |
DJL提供了本机Java开发经验,其功能与其他Java库一样。 这些API旨在指导开发人员以最佳实践完成深度学习任务。 在开始DJL之前,需要对ML生命周期有一个很好的了解。 如果您不熟悉ML,请阅读概述,或者从InfoQ的文章系列开始,这是针对软件开发人员的机器学习简介 。 了解了生命周期和常见的ML术语后,开发人员可以快速掌握DJL的API。
亚马逊已经开源了DJL,有关该工具包的更多详细信息可以在DJL网站和Java Library API Specification页面上找到。 可以查看鞋类分类模型的代码,以进一步探索示例。
翻译自: https://www.infoq.com/articles/djl-deep-learning-java/?topicPageSponsorship=c1246725-b0a7-43a6-9ef9-68102c8d48e1
数据库djl