android设备模型,Pytroch模型部署到Android设备

本项目是一个简单的图像分类应用程序,演示了如何使用PyTorch Android API。此应用程序在静态图像上运行TorchScript序列化的TorchVision预训练的resnet18模型,该模型作为Android资产打包在应用程序内部。

1.模型准备

让我们从模型准备开始。如果您熟悉PyTorch,您可能应该已经知道如何训练和保存模型。如果您不这样做,我们将使用预先训练的图像分类模型(Resnet18),该模型包装在TorchVision中。要安装它,请运行以下命令:

pip install torchvision

要序列化模型,可以在HelloWorld应用的根文件夹中使用python 代码:

import torch

import torchvision

model = torchvision.models.resnet18(pretrained=True)

model.eval()

example = torch.rand(1, 3, 224, 224)

traced_script_module = torch.jit.trace(model, example)

traced_script_module.save("app/src/main/assets/model.pt")

如果一切正常,我们应该拥有我们的模型- model.pt在android应用程序的Assets文件夹中生成。它将被打包为android应用程序内部,asset并且可以在设备上使用。

2.从github克隆

git clone https://github.com/pytorch/android-demo-app.gitcd HelloWorldApp

如果已经安装了Android SDK和Android NDK,则可以使用以下命令将此应用程序安装到连接的android设备或模拟器上:

./gradlew installDebug

我们建议您在Android Studio 3.5.1+中打开此项目。目前,PyTorch Android和演示应用程序使用版本3.5.0的android gradle插件,只有Android Studio版本3.5.1和更高版本才支持。使用Android Studio,您将能够通过Android Studio UI安装Android NDK和Android SDK。

3. Gradle依赖

Pytorch android作为build.gradle中的gradle依赖项添加到项目中:

repositories {

jcenter()

}

dependencies {

implementation 'org.pytorch:pytorch_android:1.4.0'

implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'

}

org.pytorch:pytorch_androidPyTorch Android API的主要依赖项在哪里,包括所有4个Android abis(armeabi-v7a,arm64-v8a,x86,x86_64)的libtorch本机库。此外,在此文档中,您可以找到如何仅针对特定的android abis列表重建它。

org.pytorch:pytorch_android_torchvision-具有实用功能的附加库,用于转换android.media.Image和android.graphics.Bitmap张量。

4.从Android Asset读取图像

所有逻辑都发生在中org.pytorch.helloworld.MainActivity。作为第一步,我们阅读image.jpg了android.graphics.Bitmap使用标准Android API的信息。

Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));

5.加载TorchScript模块

Module module = Module.load(assetFilePath(this, "model.pt"));

org.pytorch.Module表示torch::jit::script::Module可以使用load指定序列化到文件模型的文件路径的方法加载。

6.准备输入

Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,

TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);

org.pytorch.torchvision.TensorImageUtils是org.pytorch:pytorch_android_torchvision图书馆的一部分。该TensorImageUtils#bitmapToFloat32Tensor方法在创建张量torchvision格式使用android.graphics.Bitmap作为源。

所有经过预训练的模型都希望输入图像以相同的方式归一化,即形状为(3 x H x W)的3通道RGB图像的迷你批,其中H和W至少应为224。加载到的范围内[0, 1],然后使用mean = [0.485, 0.456, 0.406]和进行归一化std = [0.229, 0.224, 0.225]

inputTensor的形状为1x3xHxW,其中H和W分别是位图的高度和宽度。

7.运行推理

Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();

float[] scores = outputTensor.getDataAsFloatArray();

org.pytorch.Module.forward方法运行加载的模块的forward方法,并org.pytorch.Tensor使用shape 获得作为outputTensor的结果1x1000。

8.处理结果

使用以下org.pytorch.Tensor.getDataAsFloatArray()方法检索其内容:该方法返回浮点数的java数组,并为每个图像网络类分配分数。

之后,我们只找到具有最高分数的索引,然后从ImageNetClasses.IMAGENET_CLASSES包含所有ImageNet类的数组中检索预测的类名。

float maxScore = -Float.MAX_VALUE;

int maxScoreIdx = -1;

for (int i = 0; i < scores.length; i++) {

if (scores[i] > maxScore) {

maxScore = scores[i];

maxScoreIdx = i;

}

}

String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];

在以下各节中,您可以找到PyTorch Android API的详细说明,用于更大的演示应用程序的代码演练,API的实现细节,如何从源代码进行自定义和构建。

PYTORCH演示应用程序

我们还创建了另一个更复杂的PyTorch Android演示应用程序,该应用程序从同一github存储库中的摄像头输出和文本分类进行图像分类。

void setupCameraX() {

final PreviewConfig previewConfig = new PreviewConfig.Builder().build();

final Preview preview = new Preview(previewConfig);

preview.setOnPreviewOutputUpdateListener(output -> mTextureView.setSurfaceTexture(output.getSurfaceTexture()));

final ImageAnalysisConfig imageAnalysisConfig =

new ImageAnalysisConfig.Builder()

.setTargetResolution(new Size(224, 224))

.setCallbackHandler(mBackgroundHandler)

.setImageReaderMode(ImageAnalysis.ImageReaderMode.ACQUIRE_LATEST_IMAGE)

.build();

final ImageAnalysis imageAnalysis = new ImageAnalysis(imageAnalysisConfig);

imageAnalysis.setAnalyzer(

(image, rotationDegrees) -> {

analyzeImage(image, rotationDegrees);

});

CameraX.bindToLifecycle(this, preview, imageAnalysis);

}

void analyzeImage(android.media.Image, int rotationDegrees)

该analyzeImage方法处理相机输出的位置android.media.Image。

从模型中获得预测分数后,它会找到分数最高的前K个类别,并在用户界面上显示。

语言处理示例

另一个示例是基于LSTM模型的自然语言处理,并在reddit注释数据集上进行了训练。逻辑发生在中TextClassificattionActivity。

结果类名称打包在TorchScript模型中,并在初始模块初始化后立即进行初始化。该模块具有一个get_classesreturn的方法,List[str]可以使用method进行调用Module.runMethod(methodName):

mModule = Module.load(moduleFileAbsoluteFilePath);

IValue getClassesOutput = mModule.runMethod("get_classes");

IValue可以将返回的值转换为IValueusing的java数组,IValue.toList()并使用以下方法处理为字符串数组IValue.toStr():

IValue[] classesListIValue = getClassesOutput.toList();

String[] moduleClasses = new String[classesListIValue.length];

int i = 0;

for (IValue iv : classesListIValue) {

moduleClasses[i++] = iv.toStr();

}

输入的文本将转换为带有UTF-8编码的java字节数组。从该字节数组Tensor.fromBlobUnsigned创建张量dtype=uint8。

byte[] bytes = text.getBytes(Charset.forName("UTF-8"));

final long[] shape = new long[]{1, bytes.length};

final Tensor inputTensor = Tensor.fromBlobUnsigned(bytes, shape);

模型的运行推断与前面的示例相似:

Tensor outputTensor = mModule.forward(IValue.from(inputTensor)).toTensor()

之后,代码处理输出,找到得分最高的类。

你可能感兴趣的:(android设备模型)