DJL Java环境下部署pytorch模型推理

由于大数据基本都是Java环境,希望与深度学习结合的话,需要将深度学习模型部署在Java环境下。传统方式使用flask搭建接口,在Java环境中对其调用,但通信时间和内存问题限制了这种方式的发展。

DJL是采用Java编写的深度学习框架,支持MXnet,Tensorflow,Pytorch引擎,这意味着同一个模型采用不同语言编写,在DJL框架中运行只需要更改依赖,代码完全一样即可执行。关于DJL更多的介绍大家可以浏览DJL官网,知乎,以及b站的课程。

知乎专栏:DJL深度学习库 - 知乎

b站课程录播:深度学习兽的个人空间_哔哩哔哩_Bilibili 

GitHub:DeepJavaLibrary · GitHub 

下面介绍部署pytorch模型步骤以及我个人遇到的一些坑,希望对大家有所帮助

首先是pom文件依赖

import torch
print(torch.__version__)

 首先使用该命令查看本地环境下的pytorch版本,根据本地的pytorch版本,选取合适的engine

PyTorch Engine - Deep Java Library

这是DJL官网的例子,也包含Linux和maxOS下的依赖配置,我的pytorch版本是1.9.0,给出我的pom文件做参考

 
        
        
            ai.djl
            api
            0.18.0
        

        
            ai.djl.pytorch
            pytorch-engine
            0.18.0
            runtime
        
        
            ai.djl.pytorch
            pytorch-native-auto
            1.9.1
            runtime
        

        
            ai.djl.pytorch
            pytorch-native-cpu
            win-x86_64
            runtime
            1.11.0
        
        
            ai.djl.pytorch
            pytorch-jni
            1.11.0-0.18.0
            runtime
        

    

 对于加载自己的本地模型,踩到的两个坑,第一个就是如果该模型是用GPU训练的,那么之后推理也需要使用GPU,如果想用CPU推理,那就需要用CPU训练网络(这一条我不确定是否正确,只是我这样修改后确实没有报错了)第二个坑就是在python中保存模型时,要使用下面的代码

net.eval()

input = np.random.uniform(0, 1, (1,1, 2048, 1))
input = input.astype(np.float32)
input = torch.from_numpy(input)
script = torch.jit.trace(net, input)
script.save(save_path+"/"+"0726+"+str(test_acc)+".pt")

使用script.save保存模型,之前我的代码是torch.save,保存的模型在DJL中加载会报错

DJL加载model首先获取本地模型的url

Path modeldir = Paths.get("D:\\1.pt");

之后重写Translator,这个需要自定义模型的输入输出类型

Translator translator = new NoBatchifyTranslator() {
            @Override
            public NDList processInput(TranslatorContext translatorContext, NDList inputs) throws Exception {
                NDManager ndManager = translatorContext.getNDManager();
                NDArray ndArray = ndManager.create(new float[2048]).reshape(1,1,2048,1);
                //ndArray作为输入
                System.out.println(ndArray);
                return new NDList(ndArray);
            }
            @Override
            public Long processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
                System.out.println("process: " + ndList);
                System.out.println("process-1:" + ndList.get(0));
                System.out.println("process-2:" + ndList.get(0).argMax());
                NDArray tmp = ndList.get(0).argMax();
                Long label =  tmp.max().getLong();
                return  label;
            }

        };

这个输入还是有问题的,传入的NDList完全没用上,一直在定义新的ndArray

translator完成后,调用Criteria,加载模型

        Criteria criteria = Criteria.builder()
                .setTypes(NDList.class,Long.class)
                .optModelPath(modeldir)
                .optTranslator(translator)
                .build();

之后调用predictor,生成预测器 

Predictor predictor = criteria.loadModel().newPredictor();

创建样本,测试样本输出(由于translator的问题,这里传什么进去结果都一样)

NDManager manager = NDManager.newBaseManager();

NDArray array = manager.randomUniform(0, 1, new Shape(2048));

NDList testarray = new NDList(array);

Long result = predictor.predict(testarray);

System.out.println("result:" + result);

 

 

 

 

 

你可能感兴趣的:(java,pytorch,深度学习)