java运行pytorch模型

引入相关依赖

pom.xml文件中添加如下依赖

<dependency>
	<groupId>org.pytorchgroupId>
	<artifactId>pytorch_java_onlyartifactId>
	<version>1.7.0version>
dependency>

按路径加载模型

Module mod = Module.load("XXX/model.pt");

输入数据处理

以nlp任务为例,可以定义一个InputFeatures类,输入为文本,将其解析为以下数据。

int[][] inputIds = new int[batchSize][seqLen];
int[][] inputMask = new int[batchSize][seqLen];
int[][] segmentIds = new int[batchSize][seqLen];

同时定义一个toLongList展平函数,将上述二维数据重整为大小为batchSizeseqLen的一维数组。fromBlob函数输入的两个参数分别是展平的一维数组和原始大小,返回大小为batchSizeseqLen的tensor。

org.pytorch.Tensor tensor_input_ids = org.pytorch.Tensor.fromBlob
(toLongList(inputIds), new long[]{batchSize, seqLen});

模型预测

首先要了解一下IValue,这是pytorch在java中的基本运算单位。之后将tensor数据转换为IValue作为forward函数的输入。

IValue result = mod.forward(
IValue.from(tensor_input_ids), 
IValue.from(tensor_input_mask), 
IValue.from(tensor_segment_ids));

输出数据解析

后续的java代码不能直接处理IValue类型的数据,因此要将其解析为传统的数组类型。由于模型输出的结果包含多项,因此result是Tuple类型,从中取出一项并返回数组的示例如下。

float [] output = result.toTuple()[0].toTensor().getDataAsFloatArray();

你可能感兴趣的:(实习生涯,Java,java,pytorch)