Android设备上部署Pytorch,实现性别识别,男女分类

上一篇文章《Pytorch实现性别识别,男女分类》
我们用pytorch实现了性别识别神经网络的训练和测试,这篇文章我们来介绍如何把训练好的模型迁移到Android设备上。

一、Android上引入pytorch

在app module下的build.gradle上加上

    implementation 'org.pytorch:pytorch_android:1.3.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.3.0'

二、把训练好的模型net.pt放到assets目录下

Android设备上部署Pytorch,实现性别识别,男女分类_第1张图片

三、编写代码

3.1、读取assets下的模型文件

private String assetFilePath(Context context, String assetName) {
        File file = new File(context.getFilesDir(), assetName);
        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        } catch (IOException e) {
            Log.e("pytorchandroid", "Error process asset " + assetName + " to file path");
        }
        return null;
    }

3.2、封装pytorch工具类


import android.graphics.Bitmap;

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

public class Classifier {
    //类别
    public static final String[] SEXS = new String[]{"男","女"};

    Module model;
    float[] mean = {0.485f, 0.456f, 0.406f};
    float[] std = {0.229f, 0.224f, 0.225f};

    /**
     * 加载assets中的模型
     * @param modelPath
     */
    public Classifier(String modelPath){
        model = Module.load(modelPath);
    }

    /**
     * 传入图片预测性别
     * @param bitmap
     * @param size 规定传入的图片要符合一个大小标准,这里是32*32
     * @return
     */
    public String predict(Bitmap bitmap, int size){
        Tensor tensor = preprocess(bitmap,size);
        IValue inputs = IValue.from(tensor);
        Tensor outputs = model.forward(inputs).toTensor();
        float[] scores = outputs.getDataAsFloatArray();
        int classIndex = argMax(scores);
        return SEXS[classIndex];
    }

    /**
     * 调整图片大小
     * @param bitmap
     * @param size
     * @return
     */
    public Tensor preprocess(Bitmap bitmap, int size){
        bitmap = Bitmap.createScaledBitmap(bitmap,size,size,false);
        return TensorImageUtils.bitmapToFloat32Tensor(bitmap,this.mean,this.std);
    }

    /**
     * 计算最大的概率
     * @param inputs
     * @return
     */
    public int argMax(float[] inputs){
        int maxIndex = -1;
        float maxvalue = 0.0f;
        for (int i = 0; i < inputs.length; i++){
            if(inputs[i] > maxvalue) {
                maxIndex = i;
                maxvalue = inputs[i];
            }
        }
        return maxIndex;
    }

}

其中调整Bitmap大小的方法很重要,否则会报错Caused by: java.lang.RuntimeException: shape '[-1, 400]' is invalid for input of size 150544 The above operation failed in interpreter.

3.3、调用模型预测

//这里的size要根据模型的需要进行改变,本模型需要32*32大小的图片
String pred = classifier.predict(bitmap, 32);

最终运行效果如下:
Android设备上部署Pytorch,实现性别识别,男女分类_第2张图片
如果您想要完整代码移步这里:https://download.csdn.net/download/zhangdongren/12358642

你可能感兴趣的:(android开发,pytorch)