【大模型开发】ONNX 格式的大模型在 Android 上的部署与测试

以下内容将以 ONNX 格式的大模型在 Android 上的部署与测试为核心,提供一套可运行的示例(基于 Android Studio/Gradle),并结合代码进行详细讲解。最后会给出一些针对在移动设备上部署 ONNX 推理的优化方法和未来建议。


目录

  1. 整体流程概述
  2. 准备工作
    2.1 ONNX 模型准备
    2.2 Android 项目准备
  3. 在 Android 上使用 ONNX Runtime
    3.1 添加依赖
    3.2 项目结构说明
    3.3 代码示例
  4. 运行与测试示例
  5. 优化方法
    5.1 模型压缩与量化
    5.2 算子融合与图优化
    5.3 硬件加速接口
  6. 未来建议

1. 整体流程概述

  1. 模型转换:将训练好的大模型从 PyTorch/TensorFlow 等框架导出为 ONNX 格式。
  2. 接入 ONNX Runtime for Android:在 Android 应用中,通过 onnxruntime-android 库进行模型推理。
  3. 编写推理逻辑:在代码中加载 ONNX 模型文件,准备输入张量,执行推理,并获取输出。
  4. 部署到手机并测试:将应用安装到 Android 设备上,测试推理速度、准确率等指标。

如果模型非常大,通常需进行模型剪枝、量化或其他优化。看前面优化文章


2. 准备工作

2.1 ONNX 模型准备

假设我们有一个 NLP 或 CV 的预训练模型(如 GPT、BERT、ResNet、YOLO 等),并且已经将其转换为 model.onnx 文件。

  • 如果你使用的是 PyTorch,可以通过 torch.onnx.export 导出;
  • 如果是 TensorFlow,可以借助 tf2onnx 或 TensorFlow 官方工具进行转换。

例如,PyTorch 导出示例(仅供参考):

import torch
import torchvision

# 示例:导出一个 pretrained ResNet18
model = torchvision.models.resnet18(pretrained=True)
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model, 
    dummy_input, 
    "model.onnx",
    input_names=["input"], 
    output_names=["output"], 
    opset_version=11
)

导出完成后,会在本地得到一个 model.onnx 文件。

2.2 Android 项目准备

  1. Android Studio:建议版本 4.0 以上。
  2. Android Gradle Plugin:建议版本与 Android Studio 保持一致。
  3. 最低 SDK 要求:一般 minSdkVersion 建议设置为 21 或更高,以支持大部分 NNAPI / 硬件加速。
  4. 设备:至少需要拥有 ARM64 架构、足够的 RAM;若模型较大,需要高端手机或者减少模型规模。

3. 在 Android 上使用 ONNX Runtime

3.1 添加依赖

ONNX Runtime 官方已提供 Android AAR 包,可在 Gradle 中直接添加依赖。
app/build.gradle 中,添加类似以下内容:

android {
    // 其他配置
    compileOptions {
        sourceCompatibility JavaVersion.VERSION_1_8
        targetCompatibility JavaVersion.VERSION_1_8
    }
    // 如果需要Kotlin,确保启用合适的编译选项
}

// 在dependencies中添加
dependencies {
    implementation 'org.onnxruntime:onnxruntime-android:1.14.1'
}

版本号可根据 ONNX Runtime 官方发布 来更新(此处以 1.14.1 为例)。

3.2 项目结构说明

假设项目结构如下(只列关键文件):

MyOnnxApp/
  ├── app/
  │   ├── src/
  │   │   ├── main/
  │   │   │   ├── AndroidManifest.xml
  │   │   │   ├── java/com/example/myonnxapp/
  │   │   │   │   ├── MainActivity.java
  │   │   │   ├── assets/
  │   │   │   │   ├── model.onnx   (ONNX文件)
  │   │   │   ├── res/
  │   │   │   │   └── layout/activity_main.xml
  │   ├── build.gradle
  ├── settings.gradle
  └── build.gradle

关键点:

  • model.onnx 放入 app/src/main/assets 目录,以便在运行时能读取模型文件。
  • MainActivity 或其他类中加载模型并执行推理。

3.3 代码示例

下面是一个简单的 Java 版本示例(Kotlin 同理),演示如何在 Android 上初始化 ONNX Runtime、加载模型并进行一次推理。这里假设输入是 [1, 3, 224, 224] 的图像张量(如典型的 ImageNet 模型),根据模型实际情况替换。

3.3.1 MainActivity.java
package com.example.myonnxapp;

import androidx.appcompat.app.AppCompatActivity;
import android.os.Bundle;
import android.widget.TextView;

import org.jetbrains.annotations.Nullable;
import org.json.JSONObject;
import org.tensorflow.lite.DataType;

import java.io.IOException;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.util.Arrays;

import ai.onnxruntime.*;

public class MainActivity extends AppCompatActivity {

    private TextView resultText;
    private OrtEnvironment env;
    private OrtSession session;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);
        resultText = findViewById(R.id.result_text);

        // 初始化 ONNX Runtime
        try {
            initOnnxRuntime();
            // 执行推理
            float[] outputScores = runInference();
            // 显示结果
            resultText.setText("Inference Output: " + Arrays.toString(outputScores));
        } catch (Exception e) {
            e.printStackTrace();
            resultText.setText("Error: " + e.getMessage());
        }
    }

    private void initOnnxRuntime() throws OrtException {
        // 创建 ORT 环境
        env = OrtEnvironment.getEnvironment();
        // 构建 SessionOptions
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        // 可选: 使用 CPU 或 NNAPI 等加速,如果需要,可启用如下:
        // sessionOptions.addNnapi();
        
        // 从assets加载模型
        try {
            InputStream modelStream = getAssets().open("model.onnx");
            byte[] modelBytes = new byte[modelStream.available()];
            modelStream.read(modelBytes);
            session = env.createSession(modelBytes, sessionOptions);
        } catch (IOException ioException) {
            throw new RuntimeException("Failed to load model from assets", ioException);
        }
    }

    private float[] runInference() throws OrtException {
        // 准备输入张量
        // 假设输入大小 [1, 3, 224, 224],数据类型 float32
        float[] inputData = new float[1 * 3 * 224 * 224];
        // 这里示例: 全部填充随机值 or 0.5f 
        // 实际中可来自图像预处理
        for (int i = 0; i < inputData.length; i++) {
            inputData[i] = 0.5f; 
        }

        // ONNX Runtime需要将Java数组包装成OnnxTensor
        long[] inputShape = new long[]{1, 3, 224, 224};
        OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData, inputShape);

        // 准备输入名 (与导出时的 input_names 对应)
        String inputName = session.getInputNames().iterator().next();

        // 运行会话
        OrtSession.Result result = session.run(Collections.singletonMap(inputName, inputTensor));

        // 假设输出名为 "output",或者取 getOutputNames() 的第一个
        String outputName = session.getOutputNames().iterator().next();
        float[][] outputRaw = (float[][]) result.get(0).getValue();

        // 此时 outputRaw 可能为 [1, num_classes],示例中只返回数组
        float[] outputScores = outputRaw[0];

        // 释放资源
        inputTensor.close();
        result.close();
        return outputScores;
    }

    @Override
    protected void onDestroy() {
        super.onDestroy();
        // 关闭 Session 和 Env,避免内存泄漏
        if (session != null) {
            try {
                session.close();
            } catch (OrtException e) {
                e.printStackTrace();
            }
        }
        if (env != null) {
            try {
                env.close();
            } catch (OrtException e) {
                e.printStackTrace();
            }
        }
    }
}
3.3.2 activity_main.xml


    
    

以上示例中:

  1. initOnnxRuntime(): 初始化 ONNX Runtime 环境,加载 assets 中的 model.onnx
  2. runInference(): 构造一个假输入,并调用 session.run() 获取推理结果。
  3. 若要在手机上处理实际图像或文本,需要在推理前进行预处理(图像缩放、归一化,或分词、构造输入 ID 等)。
  4. 若模型输出不止一个,需根据真实的模型输出名(session.getOutputNames()可获取)来取出对应的张量。

4. 运行与测试示例

  1. 将 ONNX 模型放入 assets:确保 model.onnx 存在并可读取。
  2. 编译并运行:连接 Android 设备(或使用模拟器,但大型模型推理更适合真机),点击 Run,查看日志或界面的输出信息。
  3. 检查结果:若模型是分类网络,可以对 outputScores 做 argmax,得到类别索引,并在界面中显示。

在真实应用中,你可以:

  • 通过手机摄像头或相册加载图像 -> 预处理 -> 输入模型 -> 输出预测结果。
  • 如果是语言模型或其他结构,也需要对应的数据预处理与后处理流程。

5. 优化方法

对于大模型,在移动端或嵌入式设备上的推理可能存在 内存、速度、功耗 等瓶颈。可从以下几个角度进行优化。

5.1 模型压缩与量化

  1. Post-Training Quantization:将 FP32 转为 INT8;配合校准数据,可显著减少模型体积并提升推理速度。
  2. Knowledge Distillation:训练一个尺寸更小的“学生模型”,在移动端部署。
  3. Pruning / Sparsity:移除不重要的权重或通道(需要硬件和库支持稀疏加速)。

ONNX Runtime 也支持量化后的模型推理,但需确保量化算子在该版本的 ORT for Android 上可用。

5.2 算子融合与图优化

  1. ONNX Graph Optimizer:在导出后,可对 ONNX 图进行融合、消除冗余节点等处理。ONNX Runtime 默认会进行部分优化。
  2. 减少不必要的操作:将预处理/后处理逻辑尽量简化,或者放到端上原生代码里去执行。

5.3 硬件加速接口

  1. NNAPI(Android):可通过 sessionOptions.addNnapi() 启用 Android NNAPI,让系统层面自动调度 GPU/NPU。
  2. GPU Delegate:ONNX Runtime 提供部分 GPU 后端支持,但兼容度可能不及 TensorFlow Lite GPU Delegate。
  3. DSP / NPU 厂商库:某些芯片厂商提供自定义加速库,可将 ONNX 模型进一步编译成特定格式。

6. 未来建议

  1. 更灵活的混合部署
    对于超大模型,可以考虑在云端服务器执行大部分推理或粗特征提取,只在移动端做小模型的精调或快速推理。
  2. 分片与流式推理
    在内存特别有限时,可以将模型分成多段,分批加载计算。
  3. 持续关注 ONNX Runtime 更新
    随着新版本的推出,硬件加速和量化等特性会不断完善。
  4. 结合边缘专用硬件
    如果有专用设备(如 Google Coral TPU、NVIDIA Jetson NX、ARM Ethos 等),可考虑将 ONNX 模型部署到相应 SDK 中,大幅提升性能。
  5. 结合自动化 NAS
    如果对精度和性能要求极高,可使用神经网络架构搜索(NAS)寻找更适合移动设备的模型结构,在保持效果的同时极大降低推理成本。

总结

通过以上步骤,大家可以将 ONNX 格式的大模型在 Android 设备上进行推理测试,并结合 ONNX Runtime 的接口进行快速部署。对于在移动端部署大模型,建议在精度与资源之间做充分的权衡,并利用量化、剪枝、蒸馏等方法进行模型优化。结合硬件加速与工具链的进一步发展,移动端也能承载越来越强大的 AI 能力,满足更多的实际业务需求。

哈佛博后带小白玩转机器学习】 哔哩哔哩_bilibili

总课时超400+,时长75+小时

你可能感兴趣的:(大模型技术开发与实践,哈佛博后带你玩转机器学习,深度学习,android,大模型部署,本地推理引擎,大模型开发,机器学习,边缘设备)