以下内容将以 ONNX 格式的大模型在 Android 上的部署与测试为核心,提供一套可运行的示例(基于 Android Studio/Gradle),并结合代码进行详细讲解。最后会给出一些针对在移动设备上部署 ONNX 推理的优化方法和未来建议。
如果模型非常大,通常需进行模型剪枝、量化或其他优化。看前面优化文章
假设我们有一个 NLP 或 CV 的预训练模型(如 GPT、BERT、ResNet、YOLO 等),并且已经将其转换为 model.onnx 文件。
torch.onnx.export
导出;tf2onnx
或 TensorFlow 官方工具进行转换。例如,PyTorch 导出示例(仅供参考):
import torch
import torchvision
# 示例:导出一个 pretrained ResNet18
model = torchvision.models.resnet18(pretrained=True)
model.eval()
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
dummy_input,
"model.onnx",
input_names=["input"],
output_names=["output"],
opset_version=11
)
导出完成后,会在本地得到一个 model.onnx
文件。
minSdkVersion
建议设置为 21 或更高,以支持大部分 NNAPI / 硬件加速。ONNX Runtime 官方已提供 Android AAR 包,可在 Gradle 中直接添加依赖。
在 app/build.gradle
中,添加类似以下内容:
android {
// 其他配置
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
// 如果需要Kotlin,确保启用合适的编译选项
}
// 在dependencies中添加
dependencies {
implementation 'org.onnxruntime:onnxruntime-android:1.14.1'
}
版本号可根据 ONNX Runtime 官方发布 来更新(此处以 1.14.1 为例)。
假设项目结构如下(只列关键文件):
MyOnnxApp/
├── app/
│ ├── src/
│ │ ├── main/
│ │ │ ├── AndroidManifest.xml
│ │ │ ├── java/com/example/myonnxapp/
│ │ │ │ ├── MainActivity.java
│ │ │ ├── assets/
│ │ │ │ ├── model.onnx (ONNX文件)
│ │ │ ├── res/
│ │ │ │ └── layout/activity_main.xml
│ ├── build.gradle
├── settings.gradle
└── build.gradle
关键点:
model.onnx
放入 app/src/main/assets 目录,以便在运行时能读取模型文件。MainActivity
或其他类中加载模型并执行推理。下面是一个简单的 Java 版本示例(Kotlin 同理),演示如何在 Android 上初始化 ONNX Runtime、加载模型并进行一次推理。这里假设输入是 [1, 3, 224, 224]
的图像张量(如典型的 ImageNet 模型),根据模型实际情况替换。
MainActivity.java
package com.example.myonnxapp;
import androidx.appcompat.app.AppCompatActivity;
import android.os.Bundle;
import android.widget.TextView;
import org.jetbrains.annotations.Nullable;
import org.json.JSONObject;
import org.tensorflow.lite.DataType;
import java.io.IOException;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.util.Arrays;
import ai.onnxruntime.*;
public class MainActivity extends AppCompatActivity {
private TextView resultText;
private OrtEnvironment env;
private OrtSession session;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
resultText = findViewById(R.id.result_text);
// 初始化 ONNX Runtime
try {
initOnnxRuntime();
// 执行推理
float[] outputScores = runInference();
// 显示结果
resultText.setText("Inference Output: " + Arrays.toString(outputScores));
} catch (Exception e) {
e.printStackTrace();
resultText.setText("Error: " + e.getMessage());
}
}
private void initOnnxRuntime() throws OrtException {
// 创建 ORT 环境
env = OrtEnvironment.getEnvironment();
// 构建 SessionOptions
OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
// 可选: 使用 CPU 或 NNAPI 等加速,如果需要,可启用如下:
// sessionOptions.addNnapi();
// 从assets加载模型
try {
InputStream modelStream = getAssets().open("model.onnx");
byte[] modelBytes = new byte[modelStream.available()];
modelStream.read(modelBytes);
session = env.createSession(modelBytes, sessionOptions);
} catch (IOException ioException) {
throw new RuntimeException("Failed to load model from assets", ioException);
}
}
private float[] runInference() throws OrtException {
// 准备输入张量
// 假设输入大小 [1, 3, 224, 224],数据类型 float32
float[] inputData = new float[1 * 3 * 224 * 224];
// 这里示例: 全部填充随机值 or 0.5f
// 实际中可来自图像预处理
for (int i = 0; i < inputData.length; i++) {
inputData[i] = 0.5f;
}
// ONNX Runtime需要将Java数组包装成OnnxTensor
long[] inputShape = new long[]{1, 3, 224, 224};
OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData, inputShape);
// 准备输入名 (与导出时的 input_names 对应)
String inputName = session.getInputNames().iterator().next();
// 运行会话
OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));
// 假设输出名为 "output",或者取 getOutputNames() 的第一个
String outputName = session.getOutputNames().iterator().next();
float[][] outputRaw = (float[][]) result.get(0).getValue();
// 此时 outputRaw 可能为 [1, num_classes],示例中只返回数组
float[] outputScores = outputRaw[0];
// 释放资源
inputTensor.close();
result.close();
return outputScores;
}
@Override
protected void onDestroy() {
super.onDestroy();
// 关闭 Session 和 Env,避免内存泄漏
if (session != null) {
try {
session.close();
} catch (OrtException e) {
e.printStackTrace();
}
}
if (env != null) {
try {
env.close();
} catch (OrtException e) {
e.printStackTrace();
}
}
}
}
activity_main.xml
以上示例中:
model.onnx
。session.getOutputNames()
可获取)来取出对应的张量。model.onnx
存在并可读取。outputScores
做 argmax,得到类别索引,并在界面中显示。在真实应用中,你可以:
对于大模型,在移动端或嵌入式设备上的推理可能存在 内存、速度、功耗 等瓶颈。可从以下几个角度进行优化。
ONNX Runtime 也支持量化后的模型推理,但需确保量化算子在该版本的 ORT for Android 上可用。
sessionOptions.addNnapi()
启用 Android NNAPI,让系统层面自动调度 GPU/NPU。通过以上步骤,大家可以将 ONNX 格式的大模型在 Android 设备上进行推理测试,并结合 ONNX Runtime 的接口进行快速部署。对于在移动端部署大模型,建议在精度与资源之间做充分的权衡,并利用量化、剪枝、蒸馏等方法进行模型优化。结合硬件加速与工具链的进一步发展,移动端也能承载越来越强大的 AI 能力,满足更多的实际业务需求。
【哈佛博后带小白玩转机器学习】 哔哩哔哩_bilibili
总课时超400+,时长75+小时