TensorFlow首次快速体验

TensorFlow是深度学习中使用人数最多的框架,本文快速尝试一下其能力,方便入门

添加依赖


    org.tensorflow
    tensorflow
    1.13.1

定义图模型

示例完成一个简单的函数:

f(x, y) = z = a*x + b*y

其中a, b是常量,x, y是变量

  • 定义Graph
Graph graph = new Graph()
  • 定义常量
Operation a = graph.opBuilder("Const", "a")
        .setAttr("dtype", DataType.fromClass(Double.class))
        .setAttr("value", Tensor.create(3.0, Double.class))
        .build();
Operation b = graph.opBuilder("Const", "b")
        .setAttr("dtype", DataType.fromClass(Double.class))
        .setAttr("value", Tensor.create(2.0, Double.class))
        .build()
  • 定义变量
Operation x = graph.opBuilder("Placeholder", "x")
        .setAttr("dtype", DataType.fromClass(Double.class))
        .build();
Operation y = graph.opBuilder("Placeholder", "y")
        .setAttr("dtype", DataType.fromClass(Double.class))
        .build();
  • 定义函数
Operation ax = graph.opBuilder("Mul", "ax")
        .addInput(a.output(0))
        .addInput(x.output(0))
        .build();
Operation by = graph.opBuilder("Mul", "by")
        .addInput(b.output(0))
        .addInput(y.output(0))
        .build();
Operation z = graph.opBuilder("Add", "z")
        .addInput(ax.output(0))
        .addInput(by.output(0))
        .build();

可以看出来,用Java定义图模型比较麻烦,但是使用Python会简单很多

执行

Session session = new Session(graph);
Tensor tensor = session.runner().fetch("z")
        .feed("x", Tensor.create(3.0, Double.class))
        .feed("y", Tensor.create(6.0, Double.class))
        .run().get(0).expect(Double.class);
System.out.println(tensor.doubleValue());

图模型保存及加载

  • 保存模型
Path path = Paths.get("tensor.model");
byte[] bytes = graph.toGraphDef();
Files.write(path, bytes);
  • 加载模型
Graph graph = new Graph();
byte[] bytes = Files.readAllBytes(path);
graph.importGraphDef(bytes);

ps: 模型可以在不同语言通用,所以可以使用python训练模型,然后提供给其他语言使用,比如Java

结果

最后输出结果:21.0

参考

  • TensorFlow官网
  • Introduction to Tensorflow for Java
  • 安装 Java 版 TensorFlow

你可能感兴趣的:(TensorFlow首次快速体验)