Deploying Deep Learning Models with the DJL Library

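Table of Contents

    • The DJL Library
    • Integrating DJL into a Spring Boot Microservice
      • Add the djl-spring-boot-starter dependencies
      • Download the resnet18 model
      • Define the prediction utility class
      • Write the prediction endpoint
      • Test with Postman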

The DJL Library

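DJL (Deep Java Library) is an open-source deep learning toolkit, released under the Apache-2.0 license, that uses Java APIs to simplify training, testing, and deploying deep learning models and running inference with them.

For Java developers, this means native machine learning and deep learning models can be built and used directly in Java, which lowers the barrier to deep learning development.

With DJL's intuitive, high-level API, Java developers can train their own models or run inference with models that data scientists have pre-trained in Python.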

Integrating DJL into a Spring Boot Microservice

The project here is built with Kotlin + Gradle + Spring Boot and uses the spring-boot-starter-web dependency.

Add the djl-spring-boot-starter dependencies

dependencies {
    // Web layer for the REST endpoint
    implementation("org.springframework.boot:spring-boot-starter-web")

    // DJL starters: PyTorch engine ("auto" resolves the native library for the current platform) and Spring auto-configuration
    implementation("ai.djl.spring:djl-spring-boot-starter-pytorch-auto:0.15")
    implementation("ai.djl.spring:djl-spring-boot-starter-autoconfigure:0.15")
    implementation("net.java.dev.jna:jna:5.11.0")
    implementation("org.slf4j:slf4j-api:1.7.36")

    // Kotlin runtime and reflection (also used by Spring to read Kotlin parameter names)
    implementation("org.jetbrains.kotlin:kotlin-reflect")
    implementation("org.jetbrains.kotlin:kotlin-stdlib-jdk8")
    testImplementation("org.springframework.boot:spring-boot-starter-test")
}

Download the resnet18 model

import ai.djl.training.util.DownloadUtils
import ai.djl.training.util.ProgressBar

// Run once before starting the service; the .gz archive is decompressed into resnet18.pt
fun main() {
    DownloadUtils.download(
        "https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/resnet/0.0.1/traced_resnet18.pt.gz",
        "src/main/resources/models/resnet18/resnet18.pt", ProgressBar())
    DownloadUtils.download(
        "https://djl-ai.s3.amazonaws.com/mlrepo/model/cv/image_classification/ai/djl/pytorch/synset.txt",
        "src/main/resources/models/resnet18/synset.txt", ProgressBar())
}

The model files are downloaded to src/main/resources/models/resnet18.
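The file names matter. As far as I understand DJL's behaviour, when optModelPath points to a directory and no model name is set, the PyTorch engine looks for a TorchScript file named after the directory (here resnet18.pt), and ImageClassificationTranslator reads the class labels from synset.txt by default. The sketch below is a hypothetical pre-flight check (checkModelDir is not a DJL API) that reports missing files before the model is loaded; the naming rule it checks is an assumption based on the description above.

```kotlin
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths

// Hypothetical helper, not part of DJL: returns the names of missing entries
// in a local model directory (empty list means the layout looks complete).
fun checkModelDir(dir: Path): List<String> {
    if (!Files.isDirectory(dir)) return listOf(dir.toString())
    // Assumed naming rule: the TorchScript file is named after its directory
    val modelFile = "${dir.fileName}.pt"
    return listOf(modelFile, "synset.txt")
        .filter { !Files.isRegularFile(dir.resolve(it)) }
}

fun main() {
    val missing = checkModelDir(Paths.get("src/main/resources/models/resnet18"))
    if (missing.isEmpty()) println("model directory looks complete")
    else println("missing: $missing")
}
```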

Define the prediction utility class

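// util/PredictUtil.kt
import ai.djl.modality.Classifications
import ai.djl.modality.cv.Image
import ai.djl.modality.cv.ImageFactory
import ai.djl.modality.cv.transform.Normalize
import ai.djl.modality.cv.transform.Resize
import ai.djl.modality.cv.transform.ToTensor
import ai.djl.modality.cv.translator.ImageClassificationTranslator
import ai.djl.repository.zoo.Criteria
import ai.djl.repository.zoo.ZooModel
import ai.djl.training.util.ProgressBar
import ai.djl.translate.Pipeline
import ai.djl.translate.Translator
import org.springframework.beans.factory.DisposableBean
import org.springframework.stereotype.Component
import java.io.File
import java.nio.file.Paths

@Component
class PredictUtil : DisposableBean {

    // Load the model once when the bean is created, not on every request
    private val model: ZooModel<Image, Classifications> = loadModel()

    // Preprocessing: resize to 224x224, convert to a CHW float tensor, normalize with ImageNet mean/std
    private fun createPipeline(): Pipeline {
        val pipeline = Pipeline()
        pipeline.add(Resize(224, 224))
                .add(ToTensor())
                .add(Normalize(
                    floatArrayOf(0.485f, 0.456f, 0.406f),
                    floatArrayOf(0.229f, 0.224f, 0.225f)))
        return pipeline
    }

    private fun createTranslator(pipeline: Pipeline): Translator<Image, Classifications> {
        return ImageClassificationTranslator.builder()
            .setPipeline(pipeline)
            .optApplySoftmax(true)  // turn raw scores into probabilities
            .build()
    }

    private fun loadModel(): ZooModel<Image, Classifications> {
        val translator = createTranslator(createPipeline())
        val criteria: Criteria<Image, Classifications> = Criteria.builder()
            .setTypes(Image::class.java, Classifications::class.java)
            // Relative path: only resolves when the app is started from the project root
            .optModelPath(Paths.get("src/main/resources/models/resnet18"))
            .optTranslator(translator)
            .optProgress(ProgressBar()).build()
        return criteria.loadModel()
    }

    fun predict(file: File): String {
        // Close the input stream once the image has been decoded
        val image: Image = file.inputStream().use { ImageFactory.getInstance().fromInputStream(it) }
        // DJL recommends one Predictor per thread: create one per call and close it afterwards
        val classifications = model.newPredictor().use { it.predict(image) }
        println(classifications)
        return classifications.toString()
    }

    // Release the native (PyTorch) resources when the Spring context shuts down
    override fun destroy() = model.close()
}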
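To make the preprocessing in createPipeline() less of a black box: after Resize, ToTensor turns the HWC image with values 0 to 255 into a CHW float tensor in [0, 1], and Normalize then computes (pixel / 255 - mean_c) / std_c for each channel c using the ImageNet statistics. The sketch below redoes the last two steps in plain Kotlin on a flat array. It only illustrates the arithmetic; DJL itself does this on NDArrays, and toNormalizedChw is a made-up name.

```kotlin
// Plain-Kotlin illustration of ToTensor + Normalize (not DJL's implementation).
// ImageNet per-channel statistics, same values as in createPipeline()
val MEAN = floatArrayOf(0.485f, 0.456f, 0.406f)
val STD = floatArrayOf(0.229f, 0.224f, 0.225f)

// hwc: RGB pixels in 0..255, laid out row by row as [y][x][channel]
// returns: floats laid out as [channel][y][x] (CHW), the layout the network expects
fun toNormalizedChw(hwc: IntArray, height: Int, width: Int): FloatArray {
    val channels = MEAN.size
    require(hwc.size == height * width * channels) { "expected ${height * width * channels} values" }
    val plane = height * width
    val out = FloatArray(hwc.size)
    for (y in 0 until height) {
        for (x in 0 until width) {
            for (c in 0 until channels) {
                val scaled = hwc[(y * width + x) * channels + c] / 255f   // ToTensor: [0, 255] -> [0, 1]
                out[c * plane + y * width + x] = (scaled - MEAN[c]) / STD[c]  // Normalize
            }
        }
    }
    return out
}

fun main() {
    // A single pixel with R=255, G=128, B=0
    println(toNormalizedChw(intArrayOf(255, 128, 0), 1, 1).joinToString())
}
```

Note that the channel loop is innermost on the input side but outermost on the output side; that index swap is the whole HWC to CHW transpose.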

Write the prediction endpoint

// controller/PredictController.kt
import org.springframework.http.MediaType
import org.springframework.web.bind.annotation.PostMapping
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RequestParam
import org.springframework.web.bind.annotation.RestController
import org.springframework.web.multipart.MultipartFile
import java.io.File

@RestController
@RequestMapping(path = ["/predict"])
class PredictController(
    // Constructor injection: Spring passes in the singleton PredictUtil bean
    private val predictUtil: PredictUtil
) {

    // The image arrives as a multipart/form-data part named "file"
    @PostMapping(path = ["/"], consumes = [MediaType.MULTIPART_FORM_DATA_VALUE])
    fun predict(@RequestParam("file") file: MultipartFile): String {
        // Copy the upload to a temporary file and delete it once inference is done
        val tmp = File.createTempFile("upload", null)
        try {
            file.transferTo(tmp)
            return predictUtil.predict(tmp)
        } finally {
            tmp.delete()
        }
    }
}

Test with Postman

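Send a POST request to http://127.0.0.1:8080/predict/. In the Body tab, choose form-data, add a key named file of type File, and select an image.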
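Without Postman, the same request can be sent from Kotlin with the JDK 11+ java.net.http.HttpClient. This is a minimal sketch, assuming the service runs locally on port 8080; the form field must be named file to match the controller parameter. The boundary string is arbitrary, multipartBody and uploadImage are hypothetical helpers, and kitten.jpg is a placeholder file name.

```kotlin
import java.net.URI
import java.net.http.HttpClient
import java.net.http.HttpRequest
import java.net.http.HttpResponse
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths

// Build a multipart/form-data body containing a single file part
fun multipartBody(boundary: String, field: String, fileName: String, content: ByteArray): ByteArray {
    val head = "--$boundary\r\n" +
        "Content-Disposition: form-data; name=\"$field\"; filename=\"$fileName\"\r\n" +
        "Content-Type: application/octet-stream\r\n\r\n"
    val tail = "\r\n--$boundary--\r\n"
    return head.toByteArray() + content + tail.toByteArray()
}

// POST an image to the prediction endpoint and return the response body
fun uploadImage(image: Path, url: String = "http://127.0.0.1:8080/predict/"): String {
    val boundary = "djl-demo-boundary"
    val body = multipartBody(boundary, "file", image.fileName.toString(), Files.readAllBytes(image))
    val request = HttpRequest.newBuilder(URI.create(url))
        .header("Content-Type", "multipart/form-data; boundary=$boundary")
        .POST(HttpRequest.BodyPublishers.ofByteArray(body))
        .build()
    return HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString()).body()
}

fun main(args: Array<String>) {
    // Usage: pass the path of a local image; "kitten.jpg" is only a placeholder
    println(uploadImage(Paths.get(args.firstOrNull() ?: "kitten.jpg")))
}
```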

The request succeeds, and the response body contains the classification result returned by predict.
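The probabilities in that result come from optApplySoftmax(true): the model's raw scores (logits) are turned into probabilities, and each output index is paired with the matching line of synset.txt. The sketch below shows that post-processing in plain Kotlin, softmax followed by picking the k most likely labels. It is an illustration, not DJL's own code; softmax and topK are hypothetical names and the three labels in main are made up (the resnet18 synset.txt typically holds the 1000 ImageNet classes).

```kotlin
import kotlin.math.exp

// Numerically stable softmax: subtract the largest logit before exponentiating
fun softmax(logits: DoubleArray): DoubleArray {
    val max = logits.maxOrNull() ?: return DoubleArray(0)
    val exps = logits.map { exp(it - max) }
    val sum = exps.sum()
    return exps.map { it / sum }.toDoubleArray()
}

// Pair each probability with its label (same order as synset.txt) and keep the k best
fun topK(labels: List<String>, logits: DoubleArray, k: Int): List<Pair<String, Double>> {
    require(labels.size == logits.size) { "one label per logit expected" }
    return labels.zip(softmax(logits).toList())
        .sortedByDescending { it.second }
        .take(k)
}

fun main() {
    // Hypothetical 3-class example
    val result = topK(listOf("cat", "dog", "car"), doubleArrayOf(2.0, 1.0, 0.1), k = 2)
    result.forEach { (label, p) -> println("$label: %.4f".format(p)) }
}
```

Subtracting the maximum does not change the result, because the common factor exp(-max) cancels in the division, but it keeps exp from overflowing on large logits.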
