安卓调用onnx模型并计算

安卓平台可以通过调用onnx模型来进行计算,这为移动设备提供了更多的计算能力和应用场景。通过使用onnx模型,安卓设备可以进行复杂的计算任务,例如图像识别、语音识别等。这为移动应用的功能和性能提升提供了新的可能性。同时,开发者可以利用onnx模型来开发更加智能和高效的安卓应用,为用户提供更好的体验。总的来说,安卓调用onnx模型并进行计算的能力为移动设备的发展带来了新的机遇和挑战。

 

依赖

build.gradle

plugins {
    id 'com.android.application'
}

repositories {
    jcenter()
    maven {
        url "https://oss.sonatype.org/content/repositories/snapshots"
    }
}

android {
    signingConfigs {
        release {
            storeFile file('myapp.keystore')
            storePassword '123456'
            keyAlias 'myapp'
            keyPassword '123456'
        }
    }
    packagingOptions {
        pickFirst 'lib/arm64-v8a/libc++_shared.so'
    }
    configurations {
        extractForNativeBuild
    }
    compileSdkVersion 28
    buildToolsVersion "30.0.3"

    defaultConfig {
        applicationId "com.mobvoi.myapp"
        minSdkVersion 21
        targetSdkVersion 28
        versionCode 1
        versionName "1.0"

        testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
        externalNativeBuild {
            cmake {
                targets "myapp", "decoder_main"
                cppFlags "-std=c++11", "-DC10_USE_GLOG", "-DC10_USE_MINIMAL_GLOG", "-DANDROID", "-Wno-c++11-narrowing", "-fexceptions"
            }
        }

        ndkVersion '21.3.6528147'
        ndk {
            abiFilters 'arm64-v8a'
        }
    }

    buildTypes {
        release {
            minifyEnabled false
            signingConfig signingConfigs.release
            proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
        }
    }
    externalNativeBuild {
        cmake {
            path "src/main/cpp/CMakeLists.txt"
        }
    }
    compileOptions {
        sourceCompatibility JavaVersion.VERSION_1_8
        targetCompatibility JavaVersion.VERSION_1_8
    }
    ndkVersion '21.3.6528147'
}

dependencies {

    implementation 'androidx.appcompat:appcompat:1.2.0'
    implementation 'com.google.android.material:material:1.2.1'
    implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
    testImplementation 'junit:junit:4.+'
    androidTestImplementation 'androidx.test.ext:junit:1.1.2'
    androidTestImplementation 'androidx.test.espresso:espresso-core:3.3.0'

    implementation 'org.pytorch:pytorch_android:1.10.0'
    extractForNativeBuild 'org.pytorch:pytorch_android:1.10.0'

    implementation group: 'com.microsoft.onnxruntime', name: 'onnxruntime-android', version: '1.15.1'
}

task extractAARForNativeBuild {
    doLast {
        configurations.extractForNativeBuild.files.each {
            def file = it.absoluteFile
            copy {
                from zipTree(file)
                into "$buildDir/$file.name"
                include "headers/**"
                include "jni/**"
            }
        }
    }
}

tasks.whenTaskAdded { task ->
    if (task.name.contains('externalNativeBuild')) {
        task.dependsOn(extractAARForNativeBuild)
    }
}

准备好onnx放在assert目录下

安卓调用onnx模型并计算_第1张图片

api介绍

地址

api文档icon-default.png?t=N7T8https://javadoc.io/doc/com.microsoft.onnxruntime/onnxruntime/latest/index.html

常用的api

api 作用
OrtEnvironment.getEnvironment()
创建onnx上下文的运行环境
new OrtSession.SessionOptions()
创建会话(配置)
environment.createSession(bytes, options)
创建会话,第一个参数是模型数据,第二个是配置的参数
LongBuffer.wrap(inputValues)
将输入转换成onnx识别的输入,输入是模型识别的数据
OnnxTensor.createTensor(environment, wrap, new long[]{1, inputValues.length})
创建tensor,第一个参数是上面定义的环境,第二个参数是输入转换成模型的格式,第三个根据实际设置,为入参的矩阵格式
session.run(map)
推理,map是整合起来的数据
(long[][]) output.get(1).getValue()
获取推理结果,这里以二维数组为例

使用案例

private String getOnnx(String text) {
        OrtEnvironment environment = OrtEnvironment.getEnvironment();
        AssetManager assetManager = getAssets();
        try {
            // 创建会话
            OrtSession.SessionOptions options = new OrtSession.SessionOptions();

            // 读取模型
            InputStream stream = assetManager.open("youonnx.onnx");
            ByteArrayOutputStream byteStream = new ByteArrayOutputStream();

            byte[] buffer = new byte[4096];
            int bytesRead;
            while ((bytesRead = stream.read(buffer)) != -1) {
                byteStream.write(buffer, 0, bytesRead);
            }

            byteStream.flush();
            byte[] bytes = byteStream.toByteArray();
            OrtSession session = environment.createSession(bytes, options);
            String vocab = "vocab";
            String puncVocab = "punc_vocab";

            Map vocabMap = getFormFile(vocab, new String[]{"", ""});
            Map puncVocabMap = getFormFile(vocab, new String[]{" "});

            DataSet.NoPuncTextDataset dataset = new DataSet.NoPuncTextDataset(vocabMap, puncVocabMap);

            List list = dataset.word2seq(text);

            // 准备输入数据
            long[] inputValues = new long[list.size()];
            for (int i = 0; i < list.size(); i++) {
                inputValues[i] = list.get(i);
            }

            LongBuffer wrap = LongBuffer.wrap(inputValues);
            OnnxTensor inputTensor = OnnxTensor.createTensor(environment, wrap, new long[]{1, inputValues.length});

            long[] len = new long[]{inputValues.length};
            LongBuffer wrap2 = LongBuffer.wrap(len);
            OnnxTensor inputTensor_len = OnnxTensor.createTensor(environment, wrap2, new long[]{1});
            // 准备数据
            Map map = new HashMap<>();
            map.put("inputs", inputTensor);
            map.put("inputs_len", inputTensor_len);

            // 运行推理
            OrtSession.Result output = session.run(map);

            // 获取输出结果
            long[][] value = (long[][]) output.get(1).getValue();
            // 处理输出结果
            // todo
            session.close();
            return "you_answer"
        } catch (IOException | OrtException e) {
            throw new RuntimeException(e);
        }
    }

通过调用此函数,可以实现安卓调用onnx

你可能感兴趣的:(安卓,机器学习,android,java,后端,机器学习,onnx)