tf1和tf2转换方法不同,不同模型格式转换方法也不同,要看具体情况。tensorflow导出的模型不同格式具体是什么样,可以看这篇文章对号入座,https://www.cnblogs.com/biandekeren-blog/p/11876032.html。确定模型格式之后就可以确定需要用的函数了。不同模型格式转换到tflite对应的函数不同,各种函数详情见官网https://tensorflow.google.cn/lite/convert?hl=zh-cn
【注意这里有一个坑:tf1 和 tf2 分别支持的模型,在 models/research/object_detection/g3doc 下的 tf1_detection_zoo.md 和 tf2_detection_zoo.md 文件中有详细的列表介绍。这点是无法通过 tf2 compat v1 替代的。感谢博主「Abandon_first」的提醒,我这里因为已经有了pb文件不存在这个问题,但是以后转换总觉得会遇到,一定注意用tf1还是2转换】
我手里的模型是tf1导出的单个pb文件,对应的情况是 Frozen GraphDef转tflite,转化函数里要填的参数主要是模型相关信息,比如输入输出的name、shape之类的,可以用Netron查看。下载该软件之后可以直接打开pd文件对应的计算图。【PS:这个软件直接打开很多种格式的计算图,直接从官网下载就可以】后来我又发现input array和output array的值是导出时在一个文件中设置的,这篇博客里有解释https://blog.csdn.net/Abandon_first/article/details/118295485。这里还没有细细研究,等我自己重新训练的时候再好好看看。
tflite目前还不能支持所有的tensorflow操作,如果遇到报错:Some ops are not supported by the native TFLite runtime, you can enable TF kernels fallback using TF Select. See instructions: https://www.tensorflow.org/lite/guide/ops_select TF Select ops: …, …, …。官网给出的解决方案是用包含tf操作算子的tflite模型(TF Select),在转换之前加一行代码即可。下面这个官网给出的例子是savedmodel格式转tflite, Frozen GraphDef转tflite也是加这行代码。
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
我最后完整的转换代码如下:
import tensorflow as tf
converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
graph_def_file="D:\myhome\\fish_detection-master\models\\research\object_detection\\fish_inception_v2_graph2\\frozen_inference_graph.pb",
# both `.pb` and `.pbtxt` files are accepted.
input_arrays=['image_tensor'],
input_shapes={'image_tensor' : [None,None,None,3]},
output_arrays=['SecondStagePostprocessor/Softmax']
)
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
]
tflite_model = converter.convert()
# Save the model.
with open('model.tflite', 'wb') as f:
f.write(tflite_model)
tensorflow升到2.0之后,官方的github文件目录改了很多,有的文件夹合并了,比如contrib不见了,找教程或者加载模块的时候可能会有点混乱。2版本之后tensorflow lite部署的例子都在这里https://github.com/tensorflow/examples/tree/master/lite/examples【这里我干了一件蠢事hhh,子目录没有下载按钮,要回根目录下载】
关于运行官方demo遇到的一些问题:
1.原来的模型下载不下来,因为模型的源在国外,直接挂VPN没用,Android Studio需要配代理。于是干脆注释掉下载的Task,用本地转换好的tflite。注释这行代码就好了。【全局搜索qwq关掉繁简切换之后的Ctrl+Shift+F真好用qwq】
2.注意一贯坑人的tensorflow官方,又坑人了,给的官方例子的tflite没有metadata,博主「Abandon_first」转好了一份,当然后面换自己的模型还是逃不掉要自己转,不过可以先用现成转好的先把demo跑通。
锵锵~连接手机调试,虽然之前接触过一点安卓,但都是虚拟机,这是第一次真机调试。需要注意的有:
1.在Android studio里下载相应的SDK和驱动。
2.ADB的环境变量配置
3.用USB线连接手机,不要作死用双头type-c充电线
4.手机上开发者模式打开,USB调试打开【我傻了,我以为调成MTP传输就可以,结果调试没打开,谢谢方工呜呜,这离谱的错误我自己真的很难反应过来】,文件传输模式设置MTP。最后在设备管理器里更新驱动。完成之后在命令行和Android studio应该都可以检测到连接的安卓设备。
坑一:转metadata,直接只用tflite模型不行,需要把label_map.txt的信息也放进去。【到底是直接放标签还是json格式存疑,我直接放标签会报错。。。只好放了json格式,但是我的json内容来自pbtxt文件,迷惑ing】,官方给了转换代码
from tflite_support.metadata_writers import object_detector
from tflite_support.metadata_writers import writer_utils
ObjectDetectorWriter = object_detector.MetadataWriter
_MODEL_PATH = "D:\myhome\\fish_detection-master\models\\research\object_detection\colab_tutorials\model.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "D:\myhome\\fish_detection-master\models\\research\object_detection\colab_tutorials\label_map.txt"
_SAVE_TO_PATH = "D:\myhome\\fish_detection-master\models\\research\object_detection\colab_tutorials\\fish_detect.tflite"
# Normalization parameters is required when reprocessing the image. It is
# optional if the image pixel values are in range of [0, 255] and the input
# tensor is quantized to uint8. See the introduction for normalization and
# quantization parameters below for more details.
# https://www.tensorflow.org/lite/convert/metadata#normalization_and_quantization_parameters)
_INPUT_NORM_MEAN = 127.5
_INPUT_NORM_STD = 127.5
# Create the metadata writer.
writer = ObjectDetectorWriter.create_for_inference(
writer_utils.load_file(_MODEL_PATH), [_INPUT_NORM_MEAN], [_INPUT_NORM_STD],
[_LABEL_FILE])
# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())
# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)
坑二:使用的模型里有tensorflow lite不支持的操作,之前转换的时候,就用了选择的tf算子,这里在安卓里也要安装对应的依赖,而且这样会导致模型变大。也可以自己转,但是有点麻烦,官网有但我没看明白qwq。坑二官网给出了解决方案https://www.tensorflow.org/lite/guide/ops_select但是还是感觉一言难尽。。。
坑三:tensorflow丧心病狂的版本更新又出现了,编译报错org.tensorflow:tensorflow-lite-support:0.1.0版本找不到。去仓库找对应新版本,链接如下:【再次感谢方工】https://search.maven.org/artifact/org.tensorflow/tensorflow-lite-support
目前是卡在看坑二,依然有不支持的算子,FlexTensorArrayV3,这玩意好像不支持。在github上发现有人遇到了和我一样的问题,但是看了一百多条讨论之后好像依然没有解决https://github.com/tensorflow/tensorflow/issues/40157我新开了issues并且给tensorflow开发人员发了邮件,蹲一个回复,挣扎一下,不行我就run去试onnx了qwq。害,就像师兄说的,过程精彩一些,不要急于结果,慢慢踩坑摸索啦~
更新:自己弄AAR文件的方法,值得一试,但是我没有liunx环境,sh是liunx命令。我我我再想想办法。