MXNet 特征点提取基本流程

MXNet 特征点提取基本流程

以 Android 调用 MXNet 为例:

开源 MXNet 代码:incubator-mxnet

其中 Android 部分代码路径:incubator-mxnet/amalagamation/jni

文件 说明
org_dmlc_mxnet_Predictor.h MXNet JNI 接口声明文件
predictor.cc MXNet JNI 接口实现文件
org/dmlc/mxnet/MxnetException.java MXNet JNI 接口相关的 Java 端报错文件(示例)
org/dmlc/mxnet/Predictor.java MXNet JNI 接口相关的 Java 端接口文件(示例)

分析一下 predictor.cc 文件的每个接口功能:

MXNet 的 JNI 接口

查看 org/dmlc/mxnet/Predictor.java 文件可以知道 MXNet 的 Android 端基本接口只有 4 个。


private native static long createPredictor(byte[] symbol, byte[] params, int devType, int devId, String[] keys, int[][] shapes);
private native static void nativeFree(long handle);
private native static float[] nativeGetOutput(long handle, int index);
private native static void nativeForward(long handle, String key, float[] input);

主要功能:

接口 描述
createPredictor 初始化 MXNet predictor
nativeFree 释放 MXNet 资源(关闭 MXNet 功能)
nativeGetOutput 获取特征点信息
nativeForward 输入需要提取特征的元素数据

createPredictor


/*
 * Class:     org_dmlc_mxnet_Predictor
 * Method:    createPredictor
 * Signature: ([B[BII[Ljava/lang/String;[[I)J
 */
JNIEXPORT jlong JNICALL Java_org_dmlc_mxnet_Predictor_createPredictor
  (JNIEnv *, jclass, jbyteArray symbol, jbyteArray params, jint devType, jint devId, jobjectArray keys, jobjectArray shapes);

创建 MXNet predictor (预测器),用于对图片数据提取特征。

参数名 JNI 类型 Java 类型 说明
symbol jbyteArray byte[] 模型 symbol 数据(字节流)
params jbyteArray byte[] 模型 params 数据(字节流)
devType jint int 机器学习使用的硬件类型,支持 CPU(1), GPU(2), CPU Pinned(3) 等
devId jint int predictor 的设备 id (用于区分其它 MXNet 成员)
keys jobjectArray String[] 输入参数的名称,对于 feedforward 是 {"data"}
shapes jobjectArray int[][] 多组输入节点的 shape 数据
返回值 jlong long 返回创建的 predictor 的句柄(通过该句柄使用不同的 MXNet predictor)

nativeFree


/*
 * Class:     org_dmlc_mxnet_Predictor
 * Method:    nativeFree
 * Signature: (J)V
 */
JNIEXPORT void JNICALL Java_org_dmlc_mxnet_Predictor_nativeFree
  (JNIEnv *, jclass, jlong handle);

用于释放对应的 MXNet predictor 数据,回收资源。

参数名 JNI 类型 Java 类型 说明
handle jlong long predictor 句柄,用于找到对应 MXNet 数据进行释放

nativeGetOutput


/*
 * Class:     org_dmlc_mxnet_Predictor
 * Method:    nativeGetOutput
 * Signature: (JI)[F
 */
JNIEXPORT jfloatArray JNICALL Java_org_dmlc_mxnet_Predictor_nativeGetOutput
  (JNIEnv *, jclass, jlong handle, jint index);

获取特征点数据(已经经过机器学习根据模型提取特征点)。

参数名 JNI 类型 Java 类型 说明
handle jlong long predictor 句柄(同上)
index jint int shape 数据索引,获取第 index 组 shape 数据(MXNet 支持多种检测 shape)

nativeForward


/*
 * Class:     org_dmlc_mxnet_Predictor
 * Method:    nativeForward
 * Signature: (JLjava/lang/String;[F)V
 */
JNIEXPORT void JNICALL Java_org_dmlc_mxnet_Predictor_nativeForward
  (JNIEnv *, jclass, jlong handle, jstring key, jfloatArray input);

输入图片数据,用于提取特征点。

参数名 JNI 类型 Java 类型 说明
handle jlong long predictor 句柄(同上)
key jstring String 设置输入数据的参数名称
input jfloatArray float[] 图片数据,注意把 [ Y, X, RGB ] 的维度转为 [ RGB, Y, X ] 的维度

其中 input (float[]) 数据需要 RGB 数据(不需要 Alpha 透明度),而且还要进行维度转换:

[ 行,列,色深(RGB) ] 转为 [ 色深(RGB),行,列 ]。

另外不同的模型对 RGB 值会有一些偏移,也需要注意不同模型的参数。

参考代码如下:


public float[] inputFromImage(Bitmap[] bmps, float meanR, float meanG, float meanB) {
    if (bmps.length == 0) return null;

    int width = bmps[0].getWidth();
    int height = bmps[0].getHeight();
    float[] buf = new float[height * width * 3 * bmps.length];
    for (int x=0; x

你可能感兴趣的:(MXNet 特征点提取基本流程)