是一个resnet18模型(model.ptl),用来识别静态图片,图片和模型都存放在了assets目录下。
上一篇文章说过,build.gradle是一个配置构建文件,其中dependencies可以理解为插件加载区:implementation是远程依赖声明,意味着如本地没有所提示的插件,那就将声明的插件从别地下载到本地。
首先,我们要保存模型存入asset文件夹,就需要以下代码。
需要说明的是:traced_script_module_optimized._save_for_lite_interpreter("app/src/main/assets/model.ptl")这里的地址只能写同目录下的地址,否则会报错file can not be openned。
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile
model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter("app/src/main/assets/model.ptl")
这里对torch.jit.trace做一个说明:
【Pytorch部署】TorchScript - 知乎 (zhihu.com)https://zhuanlan.zhihu.com/p/135911580
JIT = just in time compilation即时编译是一种程序优化,目的是:
TorchScript
动态图模型通过牺牲一些高级特性来换取易用性,那到底 JIT 有哪些特性,在什么情况下不得不用到 JIT 呢?下面主要通过介绍 TorchScript(PyTorch 的 JIT 实现)来分析 JIT 到底带来了哪些好处。
- 模型部署
PyTorch 的 1.0 版本发布的最核心的两个新特性就是 JIT 和 C++ API,这两个特性一起发布不是没有道理的,JIT 是 Python 和 C++ 的桥梁,我们可以使用 Python 训练模型,然后通过 JIT 将模型转为语言无关的模块,从而让 C++ 可以非常方便得调用,从此「使用 Python 训练模型,使用 C++ 将模型部署到生产环境」对 PyTorch 来说成为了一件很容易的事。而因为使用了 C++,我们现在几乎可以把 PyTorch 模型部署到任意平台和设备上:树莓派、iOS、Android 等等…
2. 性能提升
既然是为部署生产所提供的特性,那免不了在性能上面做了极大的优化,如果推断的场景对性能要求高,则可以考虑将模型(torch.nn.Module)转换为 TorchScript Module,再进行推断。
3. 模型可视化
TensorFlow 或 Keras 对模型可视化工具(TensorBoard等)非常友好,因为本身就是静态图的编程模型,在模型定义好后整个模型的结构和正向逻辑就已经清楚了;但 PyTorch 本身是不支持的,所以 PyTorch 模型在可视化上一直表现得不好,但 JIT 改善了这一情况。现在可以使用 JIT 的 trace 功能来得到 PyTorch 模型针对某一输入的正向逻辑,通过正向逻辑可以得到模型大致的结构,但如果在 `forward` 方法中有很多条件控制语句,这依然不是一个好的方法,所以 PyTorch JIT 还提供了 Scripting 的方式,这两种方式在下文中将详细介绍。
记住,Android的所有逻辑都是从MainActivity开始的,我们不需要弄懂所有代码,AS只是一个工具,我们也不需要学会每行代码的语法结构,大概看懂代码,学会修改参数就可以了。
将jpg图片转为位图:
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
加载模型:
Module module = Module.load(assetFilePath(this, "model.ptl"));
准备输入:还记得我们在build.gradle下加载的库函数(插件)org.pytorch:pytorch_android_torchvision吗?org.pytorch.torchvision.TensorImageUtils是它的子库。
ensorImageUtils.bitmapToFloat32Tensor可以把安卓的位图,转化为torchvision可以接受的tensor形式。
tips:不知道大家是否还记得归一化,例如
将数据限定在[0,1]内。
所有的预训练模型都要求图像的归一化是统一的,举个例子:
对于RGB三通道的图片(3 * H * W),H,W分别是位图的高和宽,对所有图片归一化,
mean = [0.485, 0.456, 0.406]
andstd(标准差) = [0.229, 0.224, 0.225] 需一致。
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
调用forward方法,并得到结果
org.pytorch.Module.forward
method runs loaded module’sforward
method and gets result asorg.pytorch.Tensor
outputTensor with shape1x1000
. 返回的是一个java数组
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();
找对应结果
根据数组内的最大数值去ImageNet里去找对应的预测出的类别。
也就是从ImageNetClasses.IMAGENET_CLASSES[index]去找对应的类别
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile
model = torchvision.models.mobilenet_v2(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter("app/src/main/assets/model.ptl")
将上述代码中的torchvision.models.模型一换,再存入assets里面即可