Pytorch-Mobile-Android(2)


Android-Pytorch:QUICKSTART WITH A HELLOWORLD EXAMPLE(官网例1)

1.模型构成:

是一个resnet18模型(model.ptl),用来识别静态图片,图片和模型都存放在了assets目录下。

Pytorch-Mobile-Android(2)_第1张图片

2.Gradle Dependencies:

Pytorch-Mobile-Android(2)_第2张图片

上一篇文章说过,build.gradle是一个配置构建文件,其中dependencies可以理解为插件加载区:implementation是远程依赖声明,意味着如本地没有所提示的插件,那就将声明的插件从别地下载到本地。

3.开始阅读代码

首先,我们要保存模型存入asset文件夹,就需要以下代码。

需要说明的是:traced_script_module_optimized._save_for_lite_interpreter("app/src/main/assets/model.ptl")这里的地址只能写同目录下的地址,否则会报错file can not be openned。

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter("app/src/main/assets/model.ptl")

这里对torch.jit.trace做一个说明:

【Pytorch部署】TorchScript - 知乎 (zhihu.com)icon-default.png?t=L9C2https://zhuanlan.zhihu.com/p/135911580

JIT = just in time compilation即时编译是一种程序优化,目的是:

TorchScript

动态图模型通过牺牲一些高级特性来换取易用性,那到底 JIT 有哪些特性,在什么情况下不得不用到 JIT 呢?下面主要通过介绍 TorchScript(PyTorch 的 JIT 实现)来分析 JIT 到底带来了哪些好处。

  1. 模型部署

PyTorch 的 1.0 版本发布的最核心的两个新特性就是 JIT 和 C++ API,这两个特性一起发布不是没有道理的,JIT 是 Python 和 C++ 的桥梁,我们可以使用 Python 训练模型,然后通过 JIT 将模型转为语言无关的模块,从而让 C++ 可以非常方便得调用,从此「使用 Python 训练模型,使用 C++ 将模型部署到生产环境」对 PyTorch 来说成为了一件很容易的事。而因为使用了 C++,我们现在几乎可以把 PyTorch 模型部署到任意平台和设备上:树莓派、iOS、Android 等等…

2. 性能提升

既然是为部署生产所提供的特性,那免不了在性能上面做了极大的优化,如果推断的场景对性能要求高,则可以考虑将模型(torch.nn.Module)转换为 TorchScript Module,再进行推断。

3. 模型可视化

TensorFlow 或 Keras 对模型可视化工具(TensorBoard等)非常友好,因为本身就是静态图的编程模型,在模型定义好后整个模型的结构和正向逻辑就已经清楚了;但 PyTorch 本身是不支持的,所以 PyTorch 模型在可视化上一直表现得不好,但 JIT 改善了这一情况。现在可以使用 JIT 的 trace 功能来得到 PyTorch 模型针对某一输入的正向逻辑,通过正向逻辑可以得到模型大致的结构,但如果在 `forward` 方法中有很多条件控制语句,这依然不是一个好的方法,所以 PyTorch JIT 还提供了 Scripting 的方式,这两种方式在下文中将详细介绍。

记住,Android的所有逻辑都是从MainActivity开始的,我们不需要弄懂所有代码,AS只是一个工具,我们也不需要学会每行代码的语法结构,大概看懂代码,学会修改参数就可以了。

Pytorch-Mobile-Android(2)_第3张图片

将jpg图片转为位图:

Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));

加载模型:

Module module = Module.load(assetFilePath(this, "model.ptl"));

 准备输入:还记得我们在build.gradle下加载的库函数(插件)org.pytorch:pytorch_android_torchvision吗?org.pytorch.torchvision.TensorImageUtils是它的子库。

ensorImageUtils.bitmapToFloat32Tensor可以把安卓的位图,转化为torchvision可以接受的tensor形式。

tips:不知道大家是否还记得归一化,例如

Pytorch-Mobile-Android(2)_第4张图片

将数据限定在[0,1]内。

所有的预训练模型都要求图像的归一化是统一的,举个例子:

对于RGB三通道的图片(3 * H * W),H,W分别是位图的高和宽,对所有图片归一化,

mean = [0.485, 0.456, 0.406] and std(标准差) = [0.229, 0.224, 0.225] 需一致。

Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
    TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);

调用forward方法,并得到结果

org.pytorch.Module.forward method runs loaded module’s forward method and gets result as org.pytorch.Tensor outputTensor with shape 1x1000. 返回的是一个java数组

Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();

找对应结果

根据数组内的最大数值去ImageNet里去找对应的预测出的类别。

也就是从ImageNetClasses.IMAGENET_CLASSES[index]去找对应的类别

float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
  if (scores[i] > maxScore) {
    maxScore = scores[i];
    maxScoreIdx = i;
  }
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];

4.所用安卓类总结

Pytorch-Mobile-Android(2)_第5张图片

 5.所用pytorch类总结

Pytorch-Mobile-Android(2)_第6张图片

6.CNN模型部署pytorch mobile总结:

import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile

model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter("app/src/main/assets/model.ptl")

将上述代码中的torchvision.models.模型一换,再存入assets里面即可

你可能感兴趣的:(android,pytorch)