目前已经出现的很多的可以在移动端使用的框架caffe2、TensorFlow Lite、CNNdroid、百度MDL(mobile-deep-learning)、腾讯NCNN、亚马逊 MXNet等,我们可以直接在移动端进行开发,但是由于手机运行速度的限制我们将模型的训练过程还是放置在电脑上进行,在手机上我们可以调用训练好的模型,完成测试部分。这就需要我们在训练模型时将模型保存为Android可以调用的格式,但是对于已经训练好的模型我们就需要进行模型格式的转化。
下面我将通过我的例子讲解一下TensorFlow模型转化的过程:
我的程序使用的是TensorFlow框架,移动端使用Android studio进行开发。而想要在移动端使用TensorFlow的代码,可以通过在新建项目路径./app/build.gradle里添加代码下面一行代码实现:
implementation 'org.tensorflow:tensorflow-android:+'
但是这种的对于模型的调用需要很多的模型文件,而已经出现的移动端深度学习框架TensorFlow Lite可以更加方便的进行调用,所以我们使用这一框架,将已经训练好的模型转化为.tflite格式,然后在Android studio中配置环境,调用模型经行测试。其中关键部分是转化的过程。
转化过程:
(一)生成PB文件
我们需要做的是将训练的.pb模型转化为.tflite模型,而对于任意一个已经训练完成的TensorFlow模型都会生成几个文件如图1所示,.meta文件保存了当前图结构;.index文件保存了当前参数名;.data文件保存了当前参数值。但由于保存模型时候的格式问题,PB模型不一定存在,所以需要生成PB模型文件。
图1:模型文件
上图中根据命名习惯的不同结果可能会有略微的差别,前面的相同部分在有些文件中是synthia_498000.ckpt;而在大部分的程序中,PB模型文件是不存在的,所以我们需要利用这些文件生成PB模型;
使用代码实现:
saver = tf.train.import_meta_graph("./pre_trained_model/synthia_498000.meta", clear_devices=True)
output_nodes = ['planeMask_and_planeParam_prediction/depth_net/mask/ResizeBilinear', 'planeMask_and_planeParam_prediction/depth_net/mask/ResizeBilinear_1', 'planeMask_and_planeParam_prediction/depth_net/mask/ResizeBilinear_2',
'planeMask_and_planeParam_prediction/depth_net/param/mul'] //这里的输出节点必须设定
with tf.Session(graph=tf.get_default_graph()) as sess:
input_graph_def = sess.graph.as_graph_def()
saver.restore(sess, "./pre_trained_model/synthia_498000") //图1中的模型文件文件
output_graph_def = tf.graph_util.convert_variables_to_constants(sess, input_graph_def,output_nodes)
with open("planer.pb", "wb") as f://生成的PB文件名称
f.write(output_graph_def.SerializeToString())
注意:这里关键点在于输出节点的选择,输出节点信息可以通过.ckpt文件利用代码输出,也可以通过使用tensorboard打开graph.pbtxt文件查看。
1、 代码实现:
checkpoint_path = os.path.join( "D:\\planerecover\\pre_trained_model\\synthia_498000")
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print( "'"+key+"'"+op)
2、 Tensorboard查看:
在程序目录下找到graph.pbtxt文件所在的目录,打开命令提示符,路径到文件所在的目录下运行命令:tensorboard --logdir=./ 得到如图2所示结果:
图2:tensorboard打开路径
复制框中网址并打开得到所有节点图,如图3、图4所示:
图3:节点信息图
图4:输入输出节点确认
从图中找到相对应的模型输出节点,这里的输出节点可能有多个;对于以上两种方法个人觉得第二种方法可能更好一点,第一种由于每个人模型节点命名的习惯不同,导致节点信息不容易查看,第二种可以在图中直观的看到节点的流程图,以及每个节点的信息;如果我们在训练脚本中给了它一个名字,这样就很容易了。如果没有,则需要使用Tensorboard并为其找到自动生成的名称。以上我们就可以实现PB文件的生成。
(二)PB文件到TFLITE文件的转化
这里也可以通过两种方法实现:
1、 通过代码实现:
path="planer.pb"
outputs=output_nodes //同上面步骤相同,输出节点信息
inputs=['planeMask_and_planeParam_prediction/depth_net/cnv1/Conv2D'] //输入节点信息
input_shape={"planeMask_and_planeParam_prediction/depth_net/cnv1/Conv2D":[4,96,160,32]}
//输入的shape
converter = tf.lite.TFLiteConverter.from_frozen_graph(path, inputs,outputs,input_shape)
//这里不同的TensorFlow版本方法不同,我是用的版本是TensorFlow1.12,
//TensorFlow lite 从TensorFlow1.7版本开始支持。
//tf.contrib.lite.toco_convert
//tf.lite.toco_convert
converter.post_training_quantize = True
tflite_model=converter.convert()
open("model_pb.tflite", "wb").write(tflite_model)
注意:如果我们项目已经生成PB文件,那么节点信息查看我们也可以通过PB文件查看:
gf = tf.GraphDef()
gf.ParseFromString(open('planer.pb','rb').read())
for n in gf.node:
print ( n.name +' ===> '+n.op )
2、 通过命令行实现:
toco \
--input_file=/tmp/mnist_graph_def_with_ckpts/opt_mnist_graph.pb \ //输入文件:pb文件路径
--input_format=TENSORFLOW_GRAPHDEF \ //默认
--output_format=TFLITE \ //默认
--inference_type=FLOAT \ //默认
--input_type=FLOAT \ //默认
--input_arrays=input_tensor \ //输入节点(关键参数)查找方式同输出节点一致
--output_arrays=softmax_tensor \ //输出节点 (关键参数)
--input_shapes=1,28,28,1 \ //输入节点shape,可以在tensorboard中查看
--output_file=/tmp/mnist_graph_def_with_ckpts/mnist.tflite //输出文件
注意:这里的输入文件我使用的是第一步中直接生成的PB文件,在tensorboard中打开graph.pbtxt,我们可以看到训练脚本中生成的每个图层。然后开始了解哪些图层对于推断是必需的,哪些图层可以丢弃掉的。 绿线框起来的所有内容都用于在训练过程中调整权重。同样,input_tensor之前的所有内容也是不必要的。所以我们可以通过使用源码安装的TensorFlow下的freeze_graph.py将模型文件和权重数据整合在一起并去除无关的Op。
具体实现:
1、 代码实现:
tf.reset_default_graph()
saver = tf.train.import_meta_graph("D:\\planerecover\\pre_trained_model\\synthia_498000.meta")
with tf.Session() as sess:
saver.restore(sess, model_path)
tf.train.write_graph(sess.graph_def, './pb_model', 'model_ResNet_L152.pb') //生成PB文件
freeze_graph.freeze_graph('pb_model/model_ResNet_L152.pb',//输入文件
'',
False,
model_path,
'depth_net/param/param/weights',
'save/restore_all',
'save/float:0',
'pb_model/frozen_model_ResNet_L152.pb', //输出文件
False,
"") //冻结优化模型,其中参数有11个
2、 命令行实现:
python tensorflow/python/tools/free_graph.py \
–input_graph=some_graph_def.pb \ //这里的pb文件是用tf.train.write_graph方法保存的
–input_checkpoint=model.ckpt.1001 \
//这里若是r12以上的版本,只需给.data-00000….前面的文件名,
//如:model.ckpt.1001.data-00000-of-00001,只需写model.ckpt.1001
–output_graph=/tmp/frozen_graph.pb
–output_node_names=output_nodes //输出节点名称
(三)Android Studio 调用模型文件
配置好Android环境,安装SDK,NDK,新建一个Android项目,将生成的.tflite文件放在项目的/app/src/main/assets/ 文件夹下,修改一些配置文件:
1、在app目录下的build.gradle配置文件加上以下配置信息:
在dependencies下加上包的引用,第一个是图片加载框架Glide,第二个就是我们这个项目的核心TensorFlow Lite:
implementation ‘com.github.bumptech.glide:glide:4.3.1’
implementation ‘org.tensorflow:tensorflow-lite:0.0.0-nightly’
2、在android下加上以下代码,这个主要是限制不要对tensorflow lite的模型进行压缩,压缩之后就无法加载模型了:
//set no compress models
aaptOptions {
noCompress "tflite"
}
3、新建Java类,进行功能实现。
总结
将训练好的TensorFlow模型在移动端调用,首先需要生成可以被调用的.tflite文件,然后再在Android中开发调用模型,转化过程注意输入输出节点的选择。