PyTorch-Mobile-Android(3):部署自己的模型

一、例子:

1. 用 torch.jit.script 把模型转换为 TorchScript,不要用 torch.jit.trace

理由见:【Pytorch部署】TorchScript - 知乎 (zhihu.com) https://zhuanlan.zhihu.com/p/135911580
简单地说:trace 只记录示例输入实际执行的那一条路径,依赖数据的控制流会被固化;script 直接编译模型源码,能保留这些控制流。

import vision_transformer
from torch.utils.mobile_optimizer import optimize_for_mobile
import torch

# Destination of the exported lite-interpreter model inside the Android project's assets.
output_path = r"D:\paper_code\02android-demo-app-master\HelloWorldApp\app\src\main\assets\vit2.pt"

# Build ViT-Tiny through its public factory function.
# NOTE(review): in timm, the private `_create_vision_transformer(variant)` gets no architecture
# kwargs from the variant name, so calling it directly builds a default-width (ViT-Base sized)
# model; `vit_tiny_patch16_384()` supplies the tiny embed_dim/depth/num_heads. Confirm against
# the local vision_transformer module. Weights are not pretrained here; pass pretrained=True
# if the app's ImageNet predictions should be meaningful.
model_vit = vision_transformer.vit_tiny_patch16_384()
model_vit = model_vit.eval()

# torch.jit.script compiles the model from its Python source, so (unlike torch.jit.trace) it
# needs no example input. Its second positional parameter is the deprecated `optimize` flag:
# passing an example tensor there is what produced the "`optimize` is deprecated" warning.
scripted_module = torch.jit.script(model_vit)

# Smoke-test the scripted module with the 1x3x384x384 input the Android app will feed it,
# so a shape mismatch fails here instead of on the device.
example = torch.rand(1, 3, 384, 384)
with torch.no_grad():
    scripted_module(example)

optimized_module = optimize_for_mobile(scripted_module)
optimized_module._save_for_lite_interpreter(output_path)

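运行时会出现如下警告(只是警告,不是报错):
UserWarning: `optimize` is deprecated and has no effect. Use `with torch.jit.optimized_execution() instead
  warnings.warn(

原因:torch.jit.script 的第二个位置参数是已弃用的 optimize,上面的代码把 example 传给了它。torch.jit.script 直接编译源码,不需要示例输入,写成 torch.jit.script(model_vit) 即可消除该警告。该警告不影响运行。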

2. 用 PIL 把图像的宽和高调整为模型要求的输入尺寸(导出模型时使用的是 384×384)

from PIL import Image

# Image asset that the Android app decodes and feeds to the model.
image_path = r'D:\paper_code\02android-demo-app-master\HelloWorldApp\app\src\main\assets\image.jpg'
image = Image.open(image_path)

# Optional one-off step: uncomment both lines to resize the asset to the model's 384x384
# input. Note that saving overwrites the original asset in place.
# image = image.resize((384, 384), Image.BILINEAR)
# image.save(image_path)

# Report (width, height) so the size can be confirmed before packaging.
print(image.size)

3. 套用 PyTorch Mobile 官方 HelloWorld 示例的代码即可运行:

package org.pytorch.helloworld;

import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;

import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import org.pytorch.MemoryFormat;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

import androidx.appcompat.app.AppCompatActivity;

public class MainActivity extends AppCompatActivity {

  @Override
  protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);

    Bitmap bitmap = null;
    Module module = null;
    try {
      // creating bitmap from packaged into app android asset 'image.jpg',
      // app/src/main/assets/image.jpg
      bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
      int width = bitmap.getWidth();
      int height = bitmap.getHeight();
      Log.e("width", String.format("width %d ", width)); //总时间
      Log.e("height", String.format("height %d ", height));
      // loading serialized torchscript module from packaged into app android asset model.pt,
      // app/src/model/assets/model.pt
      module = LiteModuleLoader.load(assetFilePath(this, "vit2.pt"));
    } catch (IOException e) {
      Log.e("PytorchHelloWorld", "Error reading assets", e);
      finish();
    }


    // showing image on UI
    ImageView imageView = findViewById(R.id.image);
    imageView.setImageBitmap(bitmap);

    // preparing input tensor
    final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
        TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);





    long startTime = System.currentTimeMillis(); //起始时间

    for (int x=0; x < 100;x = x+1 ) {

      // running the model
      final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();

      // getting tensor content as java array of floats
      final float[] scores = outputTensor.getDataAsFloatArray();

      // searching for the index with maximum score
      float maxScore = -Float.MAX_VALUE;
      int maxScoreIdx = -1;
      for (int i = 0; i < scores.length; i++) {
        if (scores[i] > maxScore) {
            maxScore = scores[i];
            maxScoreIdx = i;
          }
      }

      String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
      TextView textView = findViewById(R.id.text);
      textView.setText(className);
    }

    long endTime = System.currentTimeMillis(); //结束时间
    long runTime = endTime - startTime;
    System.out.println(runTime/100.0);//一次的平均时间
    Log.e("test", String.format("方法使用时间 %d ms", runTime)); //总时间

    // showing className on UI


  }

  /**
   * Copies specified asset to the file in /files app directory and returns this file absolute path.
   *
   * @return absolute file path
   */

  public static String assetFilePath(Context context, String assetName) throws IOException {
    File file = new File(context.getFilesDir(), assetName);
    if (file.exists() && file.length() > 0) {
      return file.getAbsolutePath();
    }

    try (InputStream is = context.getAssets().open(assetName)) {
      try (OutputStream os = new FileOutputStream(file)) {
        byte[] buffer = new byte[4 * 1024];
        int read;
        while ((read = is.read(buffer)) != -1) {
          os.write(buffer, 0, read);
        }
        os.flush();
      }
      return file.getAbsolutePath();
    }
  }
}

你可能感兴趣的:(pytorch,android,人工智能)