Java使用pytorch模型进行数据推算

我的Java后台需要对数据进行分析,但找不到合适的方法,就准备用pytorch写个模型凑活着用。

使用的DJL调用pytorch引擎

Github:djl/README.md at master · deepjavalibrary/djl · GitHub

pom.xml中添加依赖:


   ai.djl.pytorch
   pytorch-engine
   0.15.0

注意version与pytorch版本有一个对应关系

PyTorch engine version PyTorch native library version
pytorch-engine:0.15.0 pytorch-native-auto: 1.8.1, 1.9.1, 1.10.0
pytorch-engine:0.14.0 pytorch-native-auto: 1.8.1, 1.9.0, 1.9.1
pytorch-engine:0.13.0 pytorch-native-auto:1.9.0
pytorch-engine:0.12.0 pytorch-native-auto:1.8.1
pytorch-engine:0.11.0 pytorch-native-auto:1.8.1
pytorch-engine:0.10.0 pytorch-native-auto:1.7.1
pytorch-engine:0.9.0 pytorch-native-auto:1.7.0
pytorch-engine:0.8.0 pytorch-native-auto:1.6.0
pytorch-engine:0.7.0 pytorch-native-auto:1.6.0
pytorch-engine:0.6.0 pytorch-native-auto:1.5.0
pytorch-engine:0.5.0 pytorch-native-auto:1.4.0
pytorch-engine:0.4.0 pytorch-native-auto:1.4.0

其他问题访问连接:PyTorch Engine - Deep Java Library


官方给出了一个图片分类的例子,我只需要纯数据不需要图片输入。

随便写了个例子 输入是[a, b] 输出一个0~1的数

还是建议用python先训练好模型,不要用Java训练。模型训练好后,首先要做的是把pytorch模型转为TorchScript,TorchScript会把模型结构和参数都加载进去的

官网原文:

There are two ways to convert your model to TorchScript: tracing and scripting. We will only demonstrate the first one, tracing, but you can find information about scripting from the PyTorch documentation. When tracing, we use an example input to record the actions taken and capture the the model architecture. This works best when your model doesn't have control flow. If you do have control flow, you will need to use the scripting approach. In DJL, we use tracing to create TorchScript for our ModelZoo models.

Here is an example of tracing in actions:

import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18(pretrained=True)

# Switch the model to eval model
model.eval()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

# Save the TorchScript model
traced_script_module.save("traced_resnet_model.pt")

如果你使用了dropout等 一定要记得加上model.eval()再保存

对于我的来说 就下面这样

model = LinearModel()

model.load_state_dict(torch.load("model.pth"))

input = torch.tensor([0.72, 0.94]).float() //根据你的模型随便创建一个输入
    
script = torch.jit.trace(model, input)
    
script.save("model.pt")

然后该写Java代码了

官网例子:Load a PyTorch Model - Deep Java Library

还有这个:03 image classification with your model - Deep Java Library

我的数据就不需要transform了 代码:

//首先创建一个模型
Model model = Model.newInstance("test");
        try {
            model.load(Paths.get("C:\\Users\\Administrator\\IdeaProjects\\PytorchInJava\\src\\main\\resources\\model.pt"));
            System.out.println(model);

            //Predictor<参数类型,返回值类型> 输入图片的话参数是Image
            //我的参数是float32 不要写成Double
            Predictor objectObjectPredictor = model.newPredictor(new NoBatchifyTranslator() {
                @Override
                public NDList processInput(TranslatorContext translatorContext, float[] input) throws Exception {
                    NDManager ndManager = translatorContext.getNDManager();
                    NDArray ndArray = ndManager.create(input);
                    //ndArray作为输入
                    System.out.println(ndArray);
                    return new NDList(ndArray);
                }
                @Override
                public Object processOutput(TranslatorContext translatorContext, NDList ndList) throws Exception {
                    System.out.println("process: " + ndList.get(0).getFloat());
                    return ndList.get(0).getFloat();
                }
            });

            float result = objectObjectPredictor.predict(new float[]{0.6144011f, 0.952401f});

            System.out.println("result: " + result);
        } catch (IOException e) {
            e.printStackTrace();
        } catch (MalformedModelException e) {
            e.printStackTrace();
        } catch (Exception e) {
            System.out.println("qunimade ");
            e.printStackTrace();
        }

输出:

Java使用pytorch模型进行数据推算_第1张图片

更新

当我打包成jar到centos7的linux中运行时,报错UnsatisfiedLinkError,经过大神的指导,问题出在我引的依赖。

修改后的依赖:

    
        8
        5.3.0
    


    
        
            ai.djl.pytorch
            pytorch-engine
            0.16.0
        
        
            ai.djl.pytorch
            pytorch-native-cpu-precxx11
            linux-x86_64
            1.9.1
            runtime
        
        
            ai.djl.pytorch
            pytorch-jni
            1.9.1-0.16.0
            runtime
        
        
            org.springframework.boot
            spring-boot-starter-web
        
    

你可能感兴趣的:(pytorch,java,人工智能)