DJL-Java开发者动手学深度学习之使用自己训练的模型进行图片分类预测

在我们上期文章中(文章请见《深度学习之图片分类》),我们使用MNIST数据集训练了自己的图片分类模型,并保存在build/model目录下。接下来,我们将使用上期训练的模型进行预测图片。

加载模型

private static Classifications predict() throws IOException, ModelException, TranslateException {
    Image img = ImageFactory.getInstance().fromUrl("https://www.d2lcoder.com/image/0.png");
    Mlp mlp = new Mlp(
        Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
        Mnist.NUM_CLASSES,
        new int[] {128, 64});

    try (Model model = Model.newInstance("mlp")) {
        model.setBlock(mlp);

        Path modelDir = Paths.get("build/model");
        model.load(modelDir);

        List<String> classes =
            IntStream.range(0, 10).mapToObj(String::valueOf).collect(Collectors.toList());
		
        Translator<Image, Classifications> translator =
            ImageClassificationTranslator.builder()
            .addTransform(new ToTensor())
            .optSynset(classes)
            .build();

        try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
            return predictor.predict(img);
        }
    }
}

预测

Classifications classifications = predict();
System.out.println(classifications);

预测结果

我们在控制台上,可以看到输出以下内容:

[
	class: "0", probability: 6.60253
	class: "6", probability: 2.95327
	class: "2", probability: 0.30343
	class: "7", probability: -4.2e-01
	class: "5", probability: -1.4e+00
]

由些可见,预测图片为0的概率最高,预测图片为5的概率最低。

关注公众号,解锁更多深度学习内容。

DJL-Java开发者动手学深度学习之使用自己训练的模型进行图片分类预测_第1张图片

你可能感兴趣的:(d2lcoder,DJL,Java开发者动手学习深度学习,分类,数据挖掘,人工智能)