由于大数据基本都是Java环境,希望与深度学习结合的话,需要将深度学习模型部署在Java环境下。传统方式使用flask搭建接口,在Java环境中对其调用,但通信时间和内存问题限制了这种方式的发展。
DJL是采用Java编写的深度学习框架,支持MXnet,Tensorflow,Pytorch引擎,这意味着同一个模型采用不同语言编写,在DJL框架中运行只需要更改依赖,代码完全一样即可执行。关于DJL更多的介绍大家可以浏览DJL官网,知乎,以及b站的课程。
知乎专栏:DJL深度学习库 - 知乎
b站课程录播:深度学习兽的个人空间_哔哩哔哩_Bilibili
GitHub:DeepJavaLibrary · GitHub
下面介绍部署pytorch模型步骤以及我个人遇到的一些坑,希望对大家有所帮助
首先是pom文件依赖
import torch
print(torch.__version__)
首先使用该命令查看本地环境下的pytorch版本,根据本地的pytorch版本,选取合适的engine
PyTorch Engine - Deep Java Library
这是DJL官网的例子,也包含Linux和maxOS下的依赖配置,我的pytorch版本是1.9.0,给出我的pom文件做参考
ai.djl
api
0.18.0
ai.djl.pytorch
pytorch-engine
0.18.0
runtime
ai.djl.pytorch
pytorch-native-auto
1.9.1
runtime
ai.djl.pytorch
pytorch-native-cpu
win-x86_64
runtime
1.11.0
ai.djl.pytorch
pytorch-jni
1.11.0-0.18.0
runtime
对于加载自己的本地模型,踩到的两个坑,第一个就是如果该模型是用GPU训练的,那么之后推理也需要使用GPU,如果想用CPU推理,那就需要用CPU训练网络(这一条我不确定是否正确,只是我这样修改后确实没有报错了)第二个坑就是在python中保存模型时,要使用下面的代码
net.eval()
input = np.random.uniform(0, 1, (1,1, 2048, 1))
input = input.astype(np.float32)
input = torch.from_numpy(input)
script = torch.jit.trace(net, input)
script.save(save_path+"/"+"0726+"+str(test_acc)+".pt")
使用script.save保存模型,之前我的代码是torch.save,保存的模型在DJL中加载会报错
DJL加载model首先获取本地模型的url
Path modeldir = Paths.get("D:\\1.pt");
之后重写Translator,这个需要自定义模型的输入输出类型
Translator translator = new NoBatchifyTranslator() {
@Override
public NDList processInput(TranslatorContext translatorContext, NDList inputs) throws Exception {
NDManager ndManager = translatorContext.getNDManager();
NDArray ndArray = ndManager.create(new float[2048]).reshape(1,1,2048,1);
//ndArray作为输入
System.out.println(ndArray);
return new NDList(ndArray);
}
@Override
public Long processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
System.out.println("process: " + ndList);
System.out.println("process-1:" + ndList.get(0));
System.out.println("process-2:" + ndList.get(0).argMax());
NDArray tmp = ndList.get(0).argMax();
Long label = tmp.max().getLong();
return label;
}
};
这个输入还是有问题的,传入的NDList完全没用上,一直在定义新的ndArray
translator完成后,调用Criteria,加载模型
Criteria criteria = Criteria.builder()
.setTypes(NDList.class,Long.class)
.optModelPath(modeldir)
.optTranslator(translator)
.build();
之后调用predictor,生成预测器
Predictor predictor = criteria.loadModel().newPredictor();
创建样本,测试样本输出(由于translator的问题,这里传什么进去结果都一样)
NDManager manager = NDManager.newBaseManager();
NDArray array = manager.randomUniform(0, 1, new Shape(2048));
NDList testarray = new NDList(array);
Long result = predictor.predict(testarray);
System.out.println("result:" + result);