Deploying a Depth Estimation Model on an Android Phone with PyTorch Mobile

  • 1. Choosing a PyTorch depth estimation model
  • 2. Modifying the model code
  • 3. Generating the .ptl model with PyTorch
  • 4. Android deployment code
  • 5. Experimental setup
  • 6. Results on the phone

1. Choosing a PyTorch depth estimation model

For the depth estimation model, this post uses the PyTorch implementation of Monodepth. Code: https://github.com/OniroAI/MonoDepth-PyTorch; paper: https://arxiv.org/abs/1609.03677.
Before working through this post, it is recommended to get PyTorch's official Android demos running (https://github.com/pytorch/android-demo-app); everything below assumes you can already run the image segmentation example from that repo.

The parts of the Monodepth code we need:
[Figure 1: the parts of the MonoDepth-PyTorch repo used in this post]

2. Modifying the model code

The network must be implemented using only operations defined by PyTorch or native Python syntax; no other third-party library such as NumPy or OpenCV may appear, because TorchScript cannot compile them. In this repo the model is defined in models_resnet.py. Taking Resnet18_md as the example, the required changes are:
1. Replace operations implemented with NumPy with native Python equivalents (a scripting smoke test follows the snippet below):

class conv(nn.Module):
    def __init__(self, num_in_layers, num_out_layers, kernel_size, stride):
        super(conv, self).__init__()
        self.kernel_size = kernel_size
        self.conv_base = nn.Conv2d(num_in_layers, num_out_layers, kernel_size=kernel_size, stride=stride)
        self.normalize = nn.BatchNorm2d(num_out_layers)

    def forward(self, x):
        p = int(np.floor((self.kernel_size-1)/2))   # NumPy call must be replaced: np.floor ==> math.floor (add "import math"), i.e. int(math.floor((self.kernel_size-1)/2))
        p2d = (p, p, p, p)
        x = self.conv_base(F.pad(x, p2d))
        x = self.normalize(x)
        return F.elu(x, inplace=True)

class maxpool(nn.Module):
    def __init__(self, kernel_size):
        super(maxpool, self).__init__()
        self.kernel_size = kernel_size

    def forward(self, x):
        p = int(np.floor((self.kernel_size-1) / 2))   # same fix here: np.floor ==> math.floor
        p2d = (p, p, p, p)
        return F.max_pool2d(F.pad(x, p2d), self.kernel_size, stride=2)
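
For reference, here is a sketch of the conv block with the fix applied, plus a TorchScript smoke test; the channel counts and input size in the test are arbitrary:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class conv(nn.Module):
    def __init__(self, num_in_layers, num_out_layers, kernel_size, stride):
        super(conv, self).__init__()
        self.kernel_size = kernel_size
        self.conv_base = nn.Conv2d(num_in_layers, num_out_layers, kernel_size=kernel_size, stride=stride)
        self.normalize = nn.BatchNorm2d(num_out_layers)

    def forward(self, x):
        p = int(math.floor((self.kernel_size-1)/2))   # math.floor is scriptable, np.floor is not
        p2d = (p, p, p, p)
        x = self.conv_base(F.pad(x, p2d))
        x = self.normalize(x)
        return F.elu(x, inplace=True)

# smoke test: scripting succeeds and the spatial size is preserved
scripted = torch.jit.script(conv(3, 16, kernel_size=3, stride=1)).eval()
print(scripted(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 16, 64, 64])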

2. Reimplement PyTorch calls that have since changed. PyTorch Mobile requires a very recent PyTorch, so some older usages no longer compile in the new version:

self.udisp4 = nn.functional.interpolate(self.disp4, scale_factor=2, mode='bilinear', align_corners=True) 
becomes:
udisp4 = nn.functional.interpolate(disp4, scale_factor=2., mode='bilinear', align_corners=True)

In recent versions scale_factor may only be a float. Moreover, udisp4 and disp4 were never declared as attributes in __init__(), so the self. prefix has to be dropped here; otherwise PyTorch Mobile fails to compile the model.
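
As a quick standalone check of the new call (the tensor shape here is arbitrary):

import torch
import torch.nn.functional as F

disp4 = torch.randn(1, 2, 32, 64)  # stand-in for the coarsest disparity map
# scale_factor is a float, and the result stays a local variable instead of
# an undeclared self.* attribute, which TorchScript would reject
udisp4 = F.interpolate(disp4, scale_factor=2., mode='bilinear', align_corners=True)
print(udisp4.shape)  # torch.Size([1, 2, 64, 128])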

3. Modify the output:

return self.disp1, self.disp2, self.disp3, self.disp4
becomes
return disp1

The original forward() returns four disparity maps so that the training loss can be computed at multiple scales; when moving to the phone we only need the largest-scale output.
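
If you prefer not to edit the repo's forward(), an alternative is a small wrapper module that discards everything but the finest scale; this is a hypothetical helper, not part of the repo, and the other fixes from this section are still needed:

import torch
import torch.nn as nn

class MobileWrapper(nn.Module):
    """Hypothetical wrapper: keeps the trained net's forward() untouched and
    returns only the largest-scale disparity, so the scripted model has a
    single tensor output for the Android side."""
    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, x):
        disp1, disp2, disp3, disp4 = self.net(x)
        return disp1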

3. Generating the .ptl model with PyTorch

import torch
from PIL import Image
from torchvision import transforms
from torch.utils.mobile_optimizer import optimize_for_mobile
from utils import get_model


image = Image.open("O:\\xxx\\0.jpg")   # load one test image to check that the output size matches expectations
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])    
input = preprocess(image) # convert to a tensor
model = get_model('resnet18_md', 3, True) # build the model; the architecture lives in models_resnet.py
model.load_state_dict(torch.load("O:\\xxx\\monodepth_resnet18_001.pth")) # load the pretrained weights, downloadable from https://github.com/OniroAI/MonoDepth-PyTorch
input = input.unsqueeze(0)
model.eval()  # switch to eval mode before the test pass: otherwise BatchNorm uses (and updates) batch statistics instead of the pretrained running statistics; eval mode is also what should be baked into the exported model
output = model(input)  # run the model on one image
print(output.shape) # check that the output shape is as expected

scripted_module = torch.jit.script(model) # the key step: convert to TorchScript; the resulting model can run on Android
optimized_scripted_module = optimize_for_mobile(scripted_module) # mobile-specific optimizations that speed up inference

# Export full jit version model (not compatible with lite interpreter)
scripted_module.save("monodepth.pt")
# Export lite interpreter version model (compatible with lite interpreter)
scripted_module._save_for_lite_interpreter("monodepth_scripted.ptl")
# using optimized lite interpreter model makes inference about 60% faster than the non-optimized lite interpreter model, which is about 6% faster than the non-optimized full jit model
optimized_scripted_module._save_for_lite_interpreter("monodepth_scripted_optimized.ptl") # per the official note above, this optimized lite model is the fastest of the three exports
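
Before moving on to Android, it is worth reloading the exported model on the desktop and confirming it reproduces the eager model's output; a minimal check continuing from the script above (torch.jit.mobile._load_for_lite_interpreter is a private helper, so treat that part as optional):

reloaded = torch.jit.load("monodepth.pt")  # reload the full JIT export
reloaded.eval()
with torch.no_grad():
    print(torch.allclose(model(input), reloaded(input), atol=1e-5))  # expect True

# the .ptl files can be spot-checked the same way with the private helper:
# from torch.jit.mobile import _load_for_lite_interpreter
# lite = _load_for_lite_interpreter("monodepth_scripted.ptl")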

4. Android deployment code

package org.pytorch.imagesegmentation;

import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.os.SystemClock;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.ProgressBar;

import androidx.appcompat.app.AppCompatActivity;

import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;


public class MainActivity extends AppCompatActivity implements Runnable {
    private ImageView mImageView;
    private Button mButtonSegment;
    private ProgressBar mProgressBar;
    private Bitmap mBitmap = null;
    private Module mModule = null;
    private int mImagename = 0;


    // Copy an asset into app-private storage and return its absolute path.
    // Any previously cached copy is deleted first so that a stale (possibly
    // broken) model file is never loaded again.
    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            file.delete();
        }

        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_main);

        try {
            mBitmap = BitmapFactory.decodeStream(getAssets().open(mImagename + ".jpg"));
        } catch (IOException e) {
            Log.e("DepthEstimation", "Error reading assets", e);
            finish();
        }

        mImageView = findViewById(R.id.imageView);
        mImageView.setImageBitmap(mBitmap);

        final Button buttonRestart = findViewById(R.id.restartButton);
        buttonRestart.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                if(mImagename > 8) mImagename = 0;
                try {
                    mBitmap = BitmapFactory.decodeStream(getAssets().open(mImagename + ".jpg"));
                    mImagename++;
                    mImageView.setImageBitmap(mBitmap);
                } catch (IOException e) {
                    Log.e("DepthEstimation", "Error reading assets", e);
                    finish();
                }
            }
        });


        mButtonSegment = findViewById(R.id.segmentButton);
        mProgressBar = (ProgressBar) findViewById(R.id.progressBar);
        mButtonSegment.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                mButtonSegment.setEnabled(false);
                mProgressBar.setVisibility(ProgressBar.VISIBLE);
                mButtonSegment.setText(getString(R.string.run_model));

                Thread thread = new Thread(MainActivity.this);
                thread.start();
            }
        });

        try {
            mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), "monodepth_scripted_optimized.ptl"));
        } catch (IOException e) {
            Log.e("DepthEstimation", "Error reading assets", e);
            finish();
        }
    }

    @Override
    public void run() {
        // normalize the bitmap with the torchvision ImageNet mean/std and
        // pack it into a 1x3xHxW float tensor
        final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(mBitmap,
                TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);

        final long startTime = SystemClock.elapsedRealtime();
        IValue outTensors = mModule.forward(IValue.from(inputTensor));
        final long inferenceTime = SystemClock.elapsedRealtime() - startTime;
        Log.d("DepthEstimation",  "inference time (ms): " + inferenceTime);
        final Tensor outputTensor = outTensors.toTensor();
        final float[] disparities = outputTensor.getDataAsFloatArray();  // raw disparity values

        int width = mBitmap.getWidth();
        int height = mBitmap.getHeight();
        ArrayList<Float> disparityList = new ArrayList<>();
        for (float value : disparities) {
            disparityList.add(value);
        }
        final Bitmap bitmap = arrayFlotToBitmap(disparityList, width, height);
        runOnUiThread(new Runnable() {
            @Override
            public void run() {
                mImageView.setImageBitmap(bitmap);
                mButtonSegment.setEnabled(true);
                mButtonSegment.setText(getString(R.string.segment));
                mProgressBar.setVisibility(ProgressBar.INVISIBLE);
            }
        });
    }

    // Min-max normalize the float array to [0, 255] and write it into a
    // grayscale ARGB_8888 bitmap (R = G = B = normalized value, alpha = 255).
    private static Bitmap arrayFlotToBitmap(List<Float> floatArray, int width, int height){

        byte alpha = (byte) 255;

        Bitmap bmp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);

        // 4 bytes per ARGB_8888 pixel, with slack in case the model output
        // holds more than width*height values
        ByteBuffer byteBuffer = ByteBuffer.allocate(width * height * 4 * 3);

        float max = Collections.max(floatArray);
        float min = Collections.min(floatArray);
        float delta = max - min;

        int i = 0;
        for (float value : floatArray){
            byte gray = (byte) (((value - min) / delta) * 255);
            byteBuffer.put(4 * i, gray);
            byteBuffer.put(4 * i + 1, gray);
            byteBuffer.put(4 * i + 2, gray);
            byteBuffer.put(4 * i + 3, alpha);
            i++;
        }
        bmp.copyPixelsFromBuffer(byteBuffer);
        return bmp;
    }

}

The implementation follows PyTorch's official image segmentation example, https://github.com/pytorch/android-demo-app/tree/master/ImageSegmentation, with two main modifications:

    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            file.delete();
            //return file.getAbsolutePath();  // this line must be removed: once a model has been read, it stays
                                              // cached on the device, so if the model you loaded the first time was
                                              // broken, every later run would load the broken copy; delete the cached model instead.
        }

        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }
    // Min-max normalize the float array to [0, 255] and write it into a
    // grayscale ARGB_8888 bitmap (R = G = B = normalized value, alpha = 255).
    private static Bitmap arrayFlotToBitmap(List<Float> floatArray, int width, int height){

        byte alpha = (byte) 255;

        Bitmap bmp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);

        // 4 bytes per ARGB_8888 pixel, with slack in case the model output
        // holds more than width*height values
        ByteBuffer byteBuffer = ByteBuffer.allocate(width * height * 4 * 3);

        float max = Collections.max(floatArray);
        float min = Collections.min(floatArray);
        float delta = max - min;

        int i = 0;
        for (float value : floatArray){
            byte gray = (byte) (((value - min) / delta) * 255);
            byteBuffer.put(4 * i, gray);
            byteBuffer.put(4 * i + 1, gray);
            byteBuffer.put(4 * i + 2, gray);
            byteBuffer.put(4 * i + 3, alpha);
            i++;
        }
        bmp.copyPixelsFromBuffer(byteBuffer);
        return bmp;
    }

This helper converts a float array into a Bitmap; it is adapted from a PyTorch GitHub issue: https://github.com/pytorch/pytorch/issues/30655.
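
For desktop-side debugging, the same min-max mapping can be reproduced in a few lines of Python; a small sketch (not part of the app), assuming disp is a 2-D numpy array:

import numpy as np
from PIL import Image

def disparity_to_image(disp: np.ndarray) -> Image.Image:
    # same per-pixel mapping as the Java helper:
    # (value - min) / (max - min), scaled to 0..255
    d = (disp - disp.min()) / (disp.max() - disp.min())
    return Image.fromarray((d * 255).astype(np.uint8), mode="L")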

5. Experimental setup

  • pytorch=1.10.0
  • Android Studio Arctic Fox (2020.3.1 Patch 4)
  • Phone: Vivo X60t Pro+ (Android 11). Any device that meets the minimum Android version required in build.gradle should work; the app was also deployed successfully on a Xiaomi Mi 9.
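
One build note: the PyTorch runtime comes in through build.gradle. For the lite-interpreter (.ptl) models used here, the official demo around the PyTorch 1.10 release declared dependencies along the lines of implementation 'org.pytorch:pytorch_android_lite:1.10.0' and implementation 'org.pytorch:pytorch_android_torchvision_lite:1.10.0'; whichever versions you use, they should match the PyTorch release used for export.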

6. Results on the phone

[Figures 2-5: depth maps produced by the app on the phone]
If you have read this far and still find the walkthrough not detailed enough, leave me a comment; I may upload a full deployment video to Bilibili.
