我的Java后台需要对数据进行分析,但找不到合适的方法,就准备用pytorch写个模型凑活着用。
使用的DJL调用pytorch引擎
Github:djl/README.md at master · deepjavalibrary/djl · GitHub
pom.xml中添加依赖:
ai.djl.pytorch
pytorch-engine
0.15.0
注意version与pytorch版本有一个对应关系
PyTorch engine version | PyTorch native library version |
---|---|
pytorch-engine:0.15.0 | pytorch-native-auto: 1.8.1, 1.9.1, 1.10.0 |
pytorch-engine:0.14.0 | pytorch-native-auto: 1.8.1, 1.9.0, 1.9.1 |
pytorch-engine:0.13.0 | pytorch-native-auto:1.9.0 |
pytorch-engine:0.12.0 | pytorch-native-auto:1.8.1 |
pytorch-engine:0.11.0 | pytorch-native-auto:1.8.1 |
pytorch-engine:0.10.0 | pytorch-native-auto:1.7.1 |
pytorch-engine:0.9.0 | pytorch-native-auto:1.7.0 |
pytorch-engine:0.8.0 | pytorch-native-auto:1.6.0 |
pytorch-engine:0.7.0 | pytorch-native-auto:1.6.0 |
pytorch-engine:0.6.0 | pytorch-native-auto:1.5.0 |
pytorch-engine:0.5.0 | pytorch-native-auto:1.4.0 |
pytorch-engine:0.4.0 | pytorch-native-auto:1.4.0 |
其他问题访问连接:PyTorch Engine - Deep Java Library
官方给出了一个图片分类的例子,我只需要纯数据不需要图片输入。
随便写了个例子 输入是[a, b] 输出一个0~1的数
还是建议用python先训练好模型,不要用Java训练。模型训练好后,首先要做的是把pytorch模型转为TorchScript,TorchScript会把模型结构和参数都加载进去的
官网原文:
There are two ways to convert your model to TorchScript: tracing and scripting. We will only demonstrate the first one, tracing, but you can find information about scripting from the PyTorch documentation. When tracing, we use an example input to record the actions taken and capture the the model architecture. This works best when your model doesn't have control flow. If you do have control flow, you will need to use the scripting approach. In DJL, we use tracing to create TorchScript for our ModelZoo models.
Here is an example of tracing in actions:
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18(pretrained=True)
# Switch the model to eval model
model.eval()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
# Save the TorchScript model
traced_script_module.save("traced_resnet_model.pt")
如果你使用了dropout等 一定要记得加上model.eval()再保存
对于我的来说 就下面这样
model = LinearModel()
model.load_state_dict(torch.load("model.pth"))
input = torch.tensor([0.72, 0.94]).float() //根据你的模型随便创建一个输入
script = torch.jit.trace(model, input)
script.save("model.pt")
然后该写Java代码了
官网例子:Load a PyTorch Model - Deep Java Library
还有这个:03 image classification with your model - Deep Java Library
我的数据就不需要transform了 代码:
//首先创建一个模型
Model model = Model.newInstance("test");
try {
model.load(Paths.get("C:\\Users\\Administrator\\IdeaProjects\\PytorchInJava\\src\\main\\resources\\model.pt"));
System.out.println(model);
//Predictor<参数类型,返回值类型> 输入图片的话参数是Image
//我的参数是float32 不要写成Double
Predictor objectObjectPredictor = model.newPredictor(new NoBatchifyTranslator() {
@Override
public NDList processInput(TranslatorContext translatorContext, float[] input) throws Exception {
NDManager ndManager = translatorContext.getNDManager();
NDArray ndArray = ndManager.create(input);
//ndArray作为输入
System.out.println(ndArray);
return new NDList(ndArray);
}
@Override
public Object processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
System.out.println("process: " + ndList.get(0).getFloat());
return ndList.get(0).getFloat();
}
});
float result = objectObjectPredictor.predict(new float[]{0.6144011f, 0.952401f});
System.out.println("result: " + result);
} catch (IOException e) {
e.printStackTrace();
} catch (MalformedModelException e) {
e.printStackTrace();
} catch (Exception e) {
System.out.println("qunimade ");
e.printStackTrace();
}
输出:
当我打包成jar到centos7的linux中运行时,报错UnsatisfiedLinkError,经过大神的指导,问题出在我引的依赖。
修改后的依赖:
8
5.3.0
ai.djl.pytorch
pytorch-engine
0.16.0
ai.djl.pytorch
pytorch-native-cpu-precxx11
linux-x86_64
1.9.1
runtime
ai.djl.pytorch
pytorch-jni
1.9.1-0.16.0
runtime
org.springframework.boot
spring-boot-starter-web