这里的深度估计模型选用PyTorch版本的Monodepth,代码地址:https://github.com/OniroAI/MonoDepth-PyTorch,论文链接:https://arxiv.org/abs/1609.03677。
建议在实现本文之前,先跑通PyTorch官方的Android示例(https://github.com/pytorch/android-demo-app),本文建立在已经能跑通其中语义分割示例的基础上。
整个网络只能使用PyTorch提供的算子或Python原生语法,不能使用Numpy、OpenCV等第三方库,因为torch.jit.script无法把这些库的调用编译成TorchScript。该例中,模型定义在models_resnet.py中,以Resnet18_md为例,需要修改的部分如下:
1. 将代码中用Numpy实现的操作替换为Python原生实现:
class conv(nn.Module):
    """Conv2d -> BatchNorm2d -> ELU with symmetric zero padding of (kernel_size - 1) // 2.

    For an odd ``kernel_size`` and ``stride == 1`` the spatial size is preserved.
    Only PyTorch ops and plain Python are used, so the module compiles with
    ``torch.jit.script`` (NumPy calls such as ``np.floor`` cannot be scripted).

    Args:
        num_in_layers: number of input channels.
        num_out_layers: number of output channels.
        kernel_size: square kernel size; must be an int because it is used in
            integer arithmetic to compute the padding.
        stride: convolution stride.
    """

    def __init__(self, num_in_layers, num_out_layers, kernel_size, stride):
        super().__init__()
        self.kernel_size = kernel_size
        # No built-in padding: padding is applied explicitly in forward() via F.pad.
        self.conv_base = nn.Conv2d(num_in_layers, num_out_layers, kernel_size=kernel_size, stride=stride)
        self.normalize = nn.BatchNorm2d(num_out_layers)

    def forward(self, x):
        # Integer floor division gives the same value as int(np.floor((k - 1) / 2)) for any
        # int k, and TorchScript can compile it (NumPy cannot be used in scripted code).
        p = (self.kernel_size - 1) // 2
        # Pad (left, right, top, bottom) of the last two dimensions with zeros.
        x = self.conv_base(F.pad(x, [p, p, p, p]))
        x = self.normalize(x)
        return F.elu(x, inplace=True)
class maxpool(nn.Module):
    """Stride-2 max pooling after symmetric zero padding of (kernel_size - 1) // 2.

    The padding value is 0 (the ``F.pad`` default), not -inf, so a window that
    overlaps the padding never returns a value below 0. Only PyTorch ops and plain
    Python are used so the module compiles with ``torch.jit.script``.

    Args:
        kernel_size: square pooling window size; must be an int.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size

    def forward(self, x):
        # Integer floor division replaces int(np.floor((k - 1) / 2)); scriptable, same value.
        p = (self.kernel_size - 1) // 2
        return F.max_pool2d(F.pad(x, [p, p, p, p]), self.kernel_size, stride=2)
2. 重新实现代码中已过时的PyTorch写法。Pytorch Mobile要求使用较新版本的PyTorch,部分旧写法在新版本中已经不再适用:
# Original code: stores the result as a new attribute on self and passes an int scale_factor.
self.udisp4 = nn.functional.interpolate(self.disp4, scale_factor=2, mode='bilinear', align_corners=True)
修改为:
# Replacement: a local variable (udisp4/disp4 are not attributes declared in __init__) and a float scale_factor (2.).
udisp4 = nn.functional.interpolate(disp4, scale_factor=2., mode='bilinear', align_corners=True)
scale_factor需要写成浮点数(2.而不是2),否则TorchScript编译时类型检查不通过;另外udisp4和disp4并没有在__init__()中定义为属性,在forward中给self新增属性会导致torch.jit.script编译报错,因此这里去掉self,改用局部变量,否则Pytorch Mobile编译会报错。
3.修改输出:
# Original: returns all four disparity scales (needed for the multi-scale training loss).
return self.disp1, self.disp2, self.disp3, self.disp4
修改为:
# Mobile export: return only disp1, the largest-scale disparity.
return disp1
原代码输出四个尺度的视差是为了计算多尺度Loss;部署到手机上时只需要输出最大尺度的视差disp1即可。
import torch
from PIL import Image
from torchvision import transforms
from torch.utils.mobile_optimizer import optimize_for_mobile
from utils import get_model
# Smoke-test the modified network on one image, then export it for PyTorch Mobile.

# Read one image to check that the output size is as expected. convert("RGB") guarantees
# 3 channels even for grayscale / RGBA / palette files, matching the 3-channel model input.
image = Image.open("O:\\xxx\\0.jpg").convert("RGB")
preprocess = transforms.Compose([
    transforms.ToTensor(),
    # ImageNet mean/std; must match the normalization applied on the device before inference.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# (C, H, W) -> (1, C, H, W): add the batch dimension the model expects.
input_tensor = preprocess(image).unsqueeze(0)

# Build the network defined in models_resnet.py (see utils.get_model for the argument meanings).
model = get_model('resnet18_md', 3, True)
# Pretrained weights: https://github.com/OniroAI/MonoDepth-PyTorch
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load("O:\\xxx\\monodepth_resnet18_001.pth", map_location="cpu"))

# Switch to inference mode BEFORE running the test image. In training mode BatchNorm normalizes
# with the batch's own statistics and updates its running mean/variance, so a forward pass here
# would overwrite the pretrained statistics that are exported below.
model.eval()
with torch.no_grad():
    output = model(input_tensor)
print(output.shape)  # check that the output size is as expected

# Compile to TorchScript; the scripted module is what runs on Android.
scripted_module = torch.jit.script(model)
# Mobile-specific optimizations to speed up on-device inference.
optimized_scripted_module = optimize_for_mobile(scripted_module)
# Export full jit version model (not compatible with lite interpreter)
scripted_module.save("monodepth.pt")
# Export lite interpreter version model (compatible with lite interpreter)
scripted_module._save_for_lite_interpreter("monodepth_scripted.ptl")
# Per the official demo: the optimized lite-interpreter model infers about 60% faster than the
# non-optimized lite-interpreter model, which in turn is about 6% faster than the non-optimized
# full-JIT model. This is the file the Android app loads.
optimized_scripted_module._save_for_lite_interpreter("monodepth_scripted_optimized.ptl")
package org.pytorch.imagesegmentation;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.os.SystemClock;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.ProgressBar;
import androidx.appcompat.app.AppCompatActivity;
import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class MainActivity extends AppCompatActivity implements Runnable {
private ImageView mImageView;
private Button mButtonSegment;
private ProgressBar mProgressBar;
private Bitmap mBitmap = null;
private Module mModule = null;
private int mImagename = 0;
public static String assetFilePath(Context context, String assetName) throws IOException {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
file.delete();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
}
}
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
try {
mBitmap = BitmapFactory.decodeStream(getAssets().open(mImagename + ".jpg"));
} catch (IOException e) {
Log.e("DepthEstimation", "Error reading assets", e);
finish();
}
mImageView = findViewById(R.id.imageView);
mImageView.setImageBitmap(mBitmap);
final Button buttonRestart = findViewById(R.id.restartButton);
buttonRestart.setOnClickListener(new View.OnClickListener() {
public void onClick(View v) {
if(mImagename > 8) mImagename = 0;
try {
mBitmap = BitmapFactory.decodeStream(getAssets().open(mImagename + ".jpg"));
mImagename++;
mImageView.setImageBitmap(mBitmap);
} catch (IOException e) {
Log.e("DepthEstimation", "Error reading assets", e);
finish();
}
}
});
mButtonSegment = findViewById(R.id.segmentButton);
mProgressBar = (ProgressBar) findViewById(R.id.progressBar);
mButtonSegment.setOnClickListener(new View.OnClickListener() {
public void onClick(View v) {
mButtonSegment.setEnabled(false);
mProgressBar.setVisibility(ProgressBar.VISIBLE);
mButtonSegment.setText(getString(R.string.run_model));
Thread thread = new Thread(MainActivity.this);
thread.start();
}
});
try {
mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), "monodepth_scripted_optimized.ptl"));
} catch (IOException e) {
Log.e("DepthEstimation", "Error reading assets", e);
finish();
}
}
@Override
public void run() {
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(mBitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
final float[] inputs = inputTensor.getDataAsFloatArray();
final long startTime = SystemClock.elapsedRealtime();
IValue outTensors = mModule.forward(IValue.from(inputTensor));
final long inferenceTime = SystemClock.elapsedRealtime() - startTime;
Log.d("DepthEstimation", "inference time (ms): " + inferenceTime);
System.out.println(inferenceTime);
final Tensor outputTensor = outTensors.toTensor();
final float[] intValues = outputTensor.getDataAsFloatArray();
int width = mBitmap.getWidth();
int height = mBitmap.getHeight();
ArrayList<Float> arralist = new ArrayList<>();
for (int i = 0 ; i< intValues.length ; i++){
arralist.add(intValues[i]);
}
final Bitmap bitmap = arrayFlotToBitmap(arralist, width, height);
runOnUiThread(new Runnable() {
@Override
public void run() {
mImageView.setImageBitmap(bitmap);
mButtonSegment.setEnabled(true);
mButtonSegment.setText(getString(R.string.segment));
mProgressBar.setVisibility(ProgressBar.INVISIBLE);
}
});
}
private static Bitmap arrayFlotToBitmap(List<Float> floatArray, int width, int height){
byte alpha = (byte) 255 ;
Bitmap bmp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888) ;
ByteBuffer byteBuffer = ByteBuffer.allocate(width*height*4*3) ;
float Maximum = Collections.max(floatArray);
float minmum = Collections.min(floatArray);
float delta = Maximum - minmum ;
int i = 0 ;
for (float value : floatArray){
byte temValue = (byte) ((byte) ((((value-minmum)/delta)*255)));
byteBuffer.put(4*i, temValue) ;
byteBuffer.put(4*i+1, temValue) ;
byteBuffer.put(4*i+2, temValue) ;
byteBuffer.put(4*i+3, alpha) ;
i++ ;
}
bmp.copyPixelsFromBuffer(byteBuffer) ;
return bmp ;
}
}
实现参考了PyTorch官方的语义分割示例(https://github.com/pytorch/android-demo-app/tree/master/ImageSegmentation),主要有两处修改:
/**
 * Copies an asset into the app's files directory and returns the absolute path of the copy.
 *
 * @throws IOException if the asset cannot be read or the copy cannot be written
 */
public static String assetFilePath(Context context, String assetName) throws IOException {
    final File target = new File(context.getFilesDir(), assetName);
    // The official demo returned target.getAbsolutePath() right here to reuse the cached copy.
    // That early return was removed: once a model has been copied it is kept on the device, so
    // a broken model copied once would keep being loaded. Delete the old copy and copy again.
    if (target.exists() && target.length() > 0) {
        target.delete();
    }
    try (InputStream in = context.getAssets().open(assetName);
         OutputStream out = new FileOutputStream(target)) {
        final byte[] chunk = new byte[4 * 1024];
        for (int n = in.read(chunk); n != -1; n = in.read(chunk)) {
            out.write(chunk, 0, n);
        }
        out.flush();
        return target.getAbsolutePath();
    }
}
/**
 * Renders a list of floats as a grayscale ARGB_8888 bitmap; see the float[] overload.
 * Kept for callers that already hold a List; it unboxes once and delegates.
 */
private static Bitmap arrayFlotToBitmap(List<Float> floatArray, int width, int height) {
    final float[] values = new float[floatArray.size()];
    int k = 0;
    for (Float v : floatArray) {
        values[k++] = v;
    }
    return arrayFlotToBitmap(values, width, height);
}

/**
 * Renders the first {@code width * height} values as a grayscale ARGB_8888 bitmap, min-max
 * normalized to 0..255 (a constant input maps to black). Extra values, e.g. a second output
 * channel, are ignored; missing values leave their pixels transparent black.
 */
private static Bitmap arrayFlotToBitmap(float[] values, int width, int height) {
    final int pixelCount = Math.min(values.length, width * height);
    float min = Float.POSITIVE_INFINITY;
    float max = Float.NEGATIVE_INFINITY;
    for (int i = 0; i < pixelCount; i++) {
        if (values[i] < min) min = values[i];
        if (values[i] > max) max = values[i];
    }
    final float delta = max - min;
    // Exactly 4 bytes (R, G, B, A) per pixel.
    final ByteBuffer byteBuffer = ByteBuffer.allocate(width * height * 4);
    for (int i = 0; i < pixelCount; i++) {
        final float normalized = delta > 0f ? (values[i] - min) / delta : 0f;
        final byte gray = (byte) (int) (normalized * 255f);
        byteBuffer.put(4 * i, gray);
        byteBuffer.put(4 * i + 1, gray);
        byteBuffer.put(4 * i + 2, gray);
        byteBuffer.put(4 * i + 3, (byte) 255);
    }
    final Bitmap bmp = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
    bmp.copyPixelsFromBuffer(byteBuffer);
    return bmp;
}
这是一个将float数组转换为Bitmap的函数,参考了PyTorch官方仓库中的issue:https://github.com/pytorch/pytorch/issues/30655