目前,Tensorflow的Java版本支持Windows、Mac OS、Linux、Android这几个操作系统。本次主要以Windows操作系统为列来介绍。**
在Windows操作系统中,如果要在Java语言中调用TensorFlow的模型,需要到TensorFlow官网的安装页面中下载一个TensorFlow的工具类包libtensorflow-1.5.0.jar,还有一个包含JNI接口的动态链接库文件压缩包libtensorflow_jni-cpu-windows-x86_64-1.5.0.zip,该压缩包展开后会得到TensorFlow_jni.dll动态链接库文件、注意,文件名中的版本号部分可能随着TensorFlow的升级而有所变化,在使用java程序调用神经网络模型的时候,这里文件都会用到。
下面是调用代码 保存的模型文件来进行预测的示例代码。
TestTF.java
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.SaveModelBundle;
import java.nio.FloatBuffer;
import java.util.Arrays;
public class TestTf {
public static void main(String[] args) {
SaveModelBudle smb =SaveModelBundle.load("export", "tag");
Session s=smb.session();
float[][] matrix={{1.0F,2.0F,3.0F,4.0F}};
System.out.println(Arrays.deepToString(matrix));
Tensor xFeed=Tensor.create(matrix);
Tensor result=s.runner.feed("x",xFeed).fetch("y").run().get(0);
FloatBuffer buf =FloatBuffer.allocate(2);
result.writeTo(buf);
System.out.println(result.toString());
System.out.println(buf.get(0));
System.out.println(buf.get(1));
}
}
主要说明的是一下几点: