【java】【31】Calling PyTorch from Java

1. Download LibTorch 1.4 or 1.6 (CPU build)

mkdir -p /opt/pytorch && cd /opt/pytorch && wget https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.4.0%2Bcpu.zip

For version 1.6, download this file instead:
https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.6.0%2Bcpu.zip
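The two download URLs follow the same pattern and differ only in the version string. The "+" in the file name is percent-encoded as %2B in the URL, while the downloaded zip keeps the "+" (which is why the unzip command uses libtorch-shared-with-deps-1.4.0+cpu.zip). A small sketch of that pattern, assuming only the CPU "shared-with-deps" builds shown here; LibtorchUrls is a hypothetical helper, not part of any PyTorch tooling.

```java
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;

public class LibtorchUrls {
    // Base path of the CPU-only LibTorch builds, taken from the URLs above
    private static final String BASE = "https://download.pytorch.org/libtorch/cpu/";

    // Local file name after download, e.g. libtorch-shared-with-deps-1.4.0+cpu.zip
    static String fileName(String version) {
        return "libtorch-shared-with-deps-" + version + "+cpu.zip";
    }

    // Download URL: the '+' becomes %2B. URLEncoder is form-encoding, which is
    // safe here only because the file name contains no spaces.
    static String downloadUrl(String version) {
        return BASE + URLEncoder.encode(fileName(version), StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        for (String v : new String[] {"1.4.0", "1.6.0"}) {
            System.out.println(downloadUrl(v));
        }
    }
}
```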

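2. Unzip

cd /opt/pytorch && unzip libtorch-shared-with-deps-1.4.0+cpu.zip

(For 1.6, unzip libtorch-shared-with-deps-1.6.0+cpu.zip instead. Either way the files end up in /opt/pytorch/libtorch.)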

3. Development

1. Add the Maven dependency to pom.xml (its version should match the LibTorch version downloaded in step 1)


<dependency>
	<groupId>org.pytorch</groupId>
	<artifactId>pytorch_java_only</artifactId>
	<version>1.4.0</version>
</dependency>

2. Write the Java demo in a new Spring Boot project

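import java.util.Arrays;

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

/**
 * References:
 * https://pytorch.org/get-started/locally/
 * https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.4.0%2Bcpu.zip
 * https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html
 * https://pytorch.org/javadoc/
 * https://github.com/dreiss/java-demo
 * https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.4.0%2Bcpu.zip
 */
@SpringBootApplication
@RestController
public class App {

    public static void main(String[] args) {
        SpringApplication.run(App.class, args);
    }

    // Health check: returns "tong" followed by the current timestamp
    @RequestMapping("ping")
    public String ping() {
        return "tong" + System.currentTimeMillis();
    }

    @RequestMapping("getDemoModel")
    public String getDemoModel() {
        // The model comes from https://github.com/dreiss/java-demo, and this code follows that demo.
        // The model is reloaded on every request: fine for a demo, slow for real traffic.
        Module mod = Module.load("/root/demo-model.pt1");

        // A 2x3 int tensor built from flat data plus a shape
        Tensor data =
            Tensor.fromBlob(
                new int[] {1, 2, 3, 4, 5, 6}, // data
                new long[] {2, 3} // shape
                );

        // The demo model's forward() takes a tensor and a scalar
        IValue result = mod.forward(IValue.from(data), IValue.from(3.0));
        Tensor output = result.toTensor();

        // getDataAsFloatArray() only works if the model returns a float32 tensor
        String shape = Arrays.toString(output.shape());
        String values = Arrays.toString(output.getDataAsFloatArray());
        System.out.println("shape: " + shape);
        System.out.println("data: " + values);
        return "shape: " + shape + ", data: " + values;
    }
}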
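Tensor means "tensor", i.e. an n-dimensional array. It represents an array of any number of dimensions with two one-dimensional arrays: the flat data and the shape.
new long[] {2, 3} // shape: {2, 3} means 2 rows and 3 columns
new int[] {1, 2, 3, 4, 5, 6}, // data: read according to the shape, these values form a 2-row, 3-column array, [[1, 2, 3], [4, 5, 6]]
The product of the shape's dimensions must equal the length of the data (2 × 3 = 6).
Tensor data =
    Tensor.fromBlob(
        new int[] {1, 2, 3, 4, 5, 6}, // data
        new long[] {2, 3} // shape
        );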
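To make the data/shape relationship concrete, here is a plain-Java sketch (no PyTorch dependency) that reads an element out of the flat array by its row and column. It assumes a contiguous row-major (C-order) layout, which is PyTorch's default for contiguous tensors; RowMajor and its methods are hypothetical helpers, not part of the org.pytorch API.

```java
import java.util.Arrays;

public class RowMajor {
    // Number of elements implied by a shape: the product of its dimensions
    static long numel(long[] shape) {
        long n = 1;
        for (long d : shape) n *= d;
        return n;
    }

    // Row-major strides: the last dimension varies fastest
    static long[] strides(long[] shape) {
        long[] s = new long[shape.length];
        long step = 1;
        for (int i = shape.length - 1; i >= 0; i--) {
            s[i] = step;
            step *= shape[i];
        }
        return s;
    }

    // Read the element at a multi-dimensional index from the flat data array
    static int get(int[] data, long[] shape, long... index) {
        if (data.length != numel(shape)) {
            throw new IllegalArgumentException("data length does not match shape");
        }
        if (index.length != shape.length) {
            throw new IllegalArgumentException("expected " + shape.length + " indices");
        }
        long[] st = strides(shape);
        long offset = 0;
        for (int i = 0; i < index.length; i++) {
            if (index[i] < 0 || index[i] >= shape[i]) {
                throw new IndexOutOfBoundsException("index " + index[i] + " in dim " + i);
            }
            offset += index[i] * st[i];
        }
        return data[(int) offset];
    }

    public static void main(String[] args) {
        int[] data = {1, 2, 3, 4, 5, 6};
        long[] shape = {2, 3};
        // Print the tensor row by row
        for (long r = 0; r < shape[0]; r++) {
            int[] row = new int[(int) shape[1]];
            for (long c = 0; c < shape[1]; c++) {
                row[(int) c] = get(data, shape, r, c);
            }
            System.out.println(Arrays.toString(row));
        }
    }
}
```

Running main prints [1, 2, 3] and then [4, 5, 6], the two rows of the tensor above.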

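3. Start the project. This only works on Linux, so the jar has to be deployed to and started on a Linux machine. -Djava.library.path must point to the lib directory of the unzipped LibTorch, which holds the native libraries.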

java -Djava.library.path=/opt/pytorch/libtorch/lib -jar pytorchDemo-0.0.1-SNAPSHOT.jar --server.port=8888
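If startup fails with an UnsatisfiedLinkError, a likely cause is that java.library.path does not include the directory holding the native JNI library. The sketch below checks a search path for that library. It assumes the library is named pytorch_jni (so libpytorch_jni.so on Linux, via System.mapLibraryName); verify the actual file name in /opt/pytorch/libtorch/lib. NativeLibCheck is a hypothetical helper.

```java
import java.io.File;
import java.util.regex.Pattern;

public class NativeLibCheck {
    // Assumed JNI library name used by the PyTorch Java bindings
    static final String LIB = "pytorch_jni";

    // Returns the first directory on the search path that contains the library, or null
    static String findLibraryDir(String searchPath, String libName) {
        String fileName = System.mapLibraryName(libName); // e.g. libpytorch_jni.so on Linux
        for (String dir : searchPath.split(Pattern.quote(File.pathSeparator))) {
            if (!dir.isEmpty() && new File(dir, fileName).isFile()) {
                return dir;
            }
        }
        return null;
    }

    public static void main(String[] args) {
        String path = System.getProperty("java.library.path", "");
        String dir = findLibraryDir(path, LIB);
        if (dir == null) {
            System.err.println(System.mapLibraryName(LIB) + " not found on java.library.path=" + path);
        } else {
            System.out.println("found " + System.mapLibraryName(LIB) + " in " + dir);
        }
    }
}
```

Run it with the same -Djava.library.path=/opt/pytorch/libtorch/lib flag as the Spring Boot jar to see whether the library would be found.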

4. Access the endpoint

http://192.168.1.25:8888/getDemoModel
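Besides a browser or curl, the endpoint can be called from Java. A minimal sketch using the JDK 11+ java.net.http client; DemoClient is a hypothetical class, and the base URL defaults to the address above but can be overridden with the first command-line argument.

```java
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;

public class DemoClient {
    // Calls GET {baseUrl}/getDemoModel and returns the response body
    static String getDemoModel(String baseUrl) throws IOException, InterruptedException {
        HttpClient client = HttpClient.newBuilder()
                .version(HttpClient.Version.HTTP_1_1) // plain HTTP/1.1, no h2c upgrade
                .connectTimeout(Duration.ofSeconds(5))
                .build();
        HttpRequest request = HttpRequest.newBuilder(URI.create(baseUrl + "/getDemoModel"))
                .timeout(Duration.ofSeconds(30)) // loading the model can take a moment
                .GET()
                .build();
        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
        if (response.statusCode() != 200) {
            throw new IllegalStateException("HTTP " + response.statusCode() + ": " + response.body());
        }
        return response.body();
    }

    public static void main(String[] args) throws Exception {
        String base = args.length > 0 ? args[0] : "http://192.168.1.25:8888";
        System.out.println(getDemoModel(base));
    }
}
```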
