#
#作者:韦访
#博客:https://blog.csdn.net/rookie_wei
#微信:1007895847
#添加微信的备注一下是CSDN的
#欢迎大家一起学习
#
之前的博客都是讲在python上怎么用tensorflow,那么,现在就来说说怎么将训练好的模型移植到安卓上。为了照顾不会安卓APP开发的童鞋,我将分为两讲,第一讲就是纯粹的python开发,主要是介绍怎么转换和验证模型。第二讲则是纯粹的安卓APP开发。这讲,我们先做些预备工作,先将模型转成tflite格式,并使用tflite将模型跑起来,以验证我们的tflite模型没问题。
环境配置:
操作系统:Win10 64位
显卡:GTX 1080ti
Python:Python3.7
TensorFlow:2.3.0
图像超分辨率(Image Super Resolution),是指将一幅低分辨率图像,恢复成高分辨率图像的技术。
如上图所示,图四表示原始的高清图,图一表示使用bicubic算法将图四缩小4倍后的低分辨率图,图二和图三则是使用SR算法处理以后恢复的高清图。
SRGAN是通过生成对抗网络(GAN)来实现超分辨率的技术,论文地址:https://arxiv.org/abs/1609.04802。有兴趣可以看看原论文,这个模型本身的知识不是今天的重点,这里简单说一下概念就可以了。
由于使用的是GAN技术,所以它由一个生成器Generator和一个判别器Discriminator构成,网络结构如下图所示。
由于GitHub上有很多开源的代码了( https://github.com/krasserm/super-resolution),所以我们就不重复造轮子了,直接拿人家训练好的模型来用就可以了。首先下载源码,
git clone https://github.com/krasserm/super-resolution.git
然后,下载它训练好的SRGAN模型:
https://drive.google.com/open?id=1u9ituA3ScttN9Vi-UkALmpO0dWQLm8Rv
上面的模型需要才能下载。不过放心,我博客末尾会给下载链接的。
下载模型后,解压到上面下载的源码的根目录下,如下图所示,
模型文件在weights文件夹下,如下图所示,
其中,gan_discriminator.h5和pre_generator.h5是判别器和预训练的生成器,我们不用它。gan_generator.h5是GAN训练出来的生成器,这个就是我们要用的模型。
下载完源码和模型后,我们先看看怎么在python中使用它。代码如下,
from model.srgan import generator
from utils import load_image, plot_sample
from model import resolve_single
import tensorflow as tf
import os
model = generator()
model.load_weights('weights/srgan/gan_generator.h5')
print(model.summary())
lr = load_image('./demo/0869x4-crop.png')
sr = resolve_single(model, lr)
plot_sample(lr, sr)
在运行代码之前,先修改一下utils.py文件,否则运行结果不会显示出来,将,
def plot_sample(lr, sr):
plt.figure(figsize=(20, 10))
images = [lr, sr]
titles = ['LR', f'SR (x{sr.shape[0] // lr.shape[0]})']
for i, (img, title) in enumerate(zip(images, titles)):
plt.subplot(1, 2, i+1)
plt.imshow(img)
plt.title(title)
plt.xticks([])
plt.yticks([])
改成
def plot_sample(lr, sr):
plt.figure(figsize=(20, 10))
images = [lr, sr]
titles = ['LR', f'SR (x{sr.shape[0] // lr.shape[0]})']
for i, (img, title) in enumerate(zip(images, titles)):
plt.subplot(1, 2, i+1)
plt.imshow(img)
plt.title(title)
plt.xticks([])
plt.yticks([])
plt.show()
然后再运行,运行结果如下,
可以看到,左边是低分辨率的图片,右边是通过SRGAN恢复的高分辨率图片(大小为原来图像的宽高分别乘以4)。可以看到,右边图片中,猫的猫发和眼睛比左边的图片清晰了很多。
由于到移植到安卓APP上,所以推荐使用tflite,因为它有实现手机上的GPU加速等优点。现在我们就将gan_generator.h5转成tflite,直接参考tensorflow官方的示例https://tensorflow.google.cn/lite/convert,运行下面的命令行即可,
tflite_convert --keras_model_file=weights\srgan\gan_generator.h5 --output_file=weights\srgan\gan_generator.tflite
运行结果:
出错了,可以在官方教程“Save and load”(https://tensorflow.google.cn/tutorials/keras/save_and_load#hdf5_format)中找到答案,
上面说了,如果自定义对象时,使用SavedModel就没问题,如果用HDF5保存模型还需要自定义get_config方法。既然HDF5那么麻烦,那么,一个最简单的方法就是,将HDF5先转成SavedModel后再转成tflite。使用下面的代码,将HDF5模型转成SavedModel模型,
from model.srgan import generator
from utils import load_image, plot_sample
from model import resolve_single
import tensorflow as tf
import os
model = generator()
model.load_weights('weights/srgan/gan_generator.h5')
print(model.summary())
model.save("saved_model")
运行上面的代码后,会生成一个saved_model的文件夹,文件夹内容如下,
接着,使用下面的命令将其转为tflite格式,
tflite_convert --saved_model_dir=saved_model --output_file=saved_model\gan_generator.tflite
运行结果,
可以看到saved_model文件夹下生成了gan_generator.tflite文件。
接下来,我们继续写python的代码来验证这个gan_generator.tflite是OK的。参考tensorflow官方教程(https://tensorflow.google.cn/lite/guide/inference),首先,导入tflite模型并打印它的输入和输出详情,代码如下,
import tensorflow as tf
import cv2
from utils import load_image, plot_sample
import numpy as np
interpreter = tf.lite.Interpreter(model_path="./saved_model/gan_generator.tflite", num_threads=4)
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print("input_details:", input_details, "\n\noutput_details:", output_details)
运行结果如下,
至少模型可以导入了。接着,模仿官方教程的python示例代码,设置模型输入和获取输出结果,代码如下,
import tensorflow as tf
import cv2
from utils import load_image, plot_sample
import numpy as np
interpreter = tf.lite.Interpreter(model_path="./saved_model/gan_generator.tflite", num_threads=4)
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print("input_details:", input_details, "\n\noutput_details:", output_details)
# Test the model on random input data.
image = load_image("demo/0869x4-crop.png")
image = np.asarray(image, np.float32)
image = np.expand_dims(image, axis=0)
print(image.shape)
interpreter.set_tensor(input_details[0]['index'], image)
interpreter.invoke()
# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data.shape)
运行结果,
input_details: [{'name': 'input_1', 'index': 0, 'shape': array([1, 1, 1, 3]), 'shape_signature': array([-1, -1, -1, 3]), 'dtype':
output_details: [{'name': 'Identity', 'index': 173, 'shape': array([1, 1, 1, 3]), 'shape_signature': array([-1, -1, -1, 3]), 'dtype':
(1, 118, 124, 3)
Traceback (most recent call last):
File "demo3.py", line 21, in
interpreter.set_tensor(input_details[0]['index'], image)
File "C:\Users\wilf\Anaconda3\lib\site-packages\tensorflow\lite\python\interpreter.py", line 407, in set_tensor
self._interpreter.SetTensor(tensor_index, value)
ValueError: Cannot set tensor: Dimension mismatch. Got 118 but expected 1 for dimension 1 of input 0.
出错,我们看错误提示,它说我们模型的输入维度不对,dimension 1需要的是1,但是我们提供的是118。
这个什么意思呢?通过input_details的log我们看到,模型输入的shape是
[1, 1, 1, 3],但是我们提供的输入图片的shape是(1, 118, 124, 3)。通过model/srgan.py里定义的SRGAN模型的网络结构可以看到,模型是输入shape其实是(None, None,3),也就是说其实模型的输入的宽高是不限制大小的,
所以,我们这里要通过resize_tensor_input函数重新设置模型的输入shape才行。对上面的代码进行修改,修改后的代码如下,
import tensorflow as tf
import cv2
from utils import load_image, plot_sample
import numpy as np
interpreter = tf.lite.Interpreter(model_path="./saved_model/gan_generator.tflite", num_threads=4)
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print("input_details:", input_details, "\n\noutput_details:", output_details)
# Test the model on random input data.
image = load_image("demo/0869x4-crop.png")
image = np.asarray(image, np.float32)
image = np.expand_dims(image, axis=0)
print(image.shape)
interpreter.resize_tensor_input(input_details[0]['index'], image.shape)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], image)
interpreter.invoke()
# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data.shape)
运行结果如下,
看样子是对了的,不过为了保险起见,我们把得到的结果显示成图片看看是不是达到预期效果。代码如下,
import tensorflow as tf
import cv2
from utils import load_image, plot_sample
import numpy as np
interpreter = tf.lite.Interpreter(model_path="./saved_model/gan_generator.tflite", num_threads=4)
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print("input_details:", input_details, "\n\noutput_details:", output_details)
# Test the model on random input data.
image = load_image("demo/0869x4-crop.png")
image = np.asarray(image, np.float32)
image = np.expand_dims(image, axis=0)
print(image.shape)
interpreter.resize_tensor_input(input_details[0]['index'], image.shape)
interpreter.allocate_tensors()
interpreter.set_tensor(input_details[0]['index'], image)
interpreter.invoke()
# The function `get_tensor()` returns a copy of the tensor data.
# Use `tensor()` in order to get a pointer to the tensor.
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data.shape)
# show result
image = np.squeeze(image)
image = tf.clip_by_value(image, 0, 255)
image = tf.round(image)
image = tf.cast(image, tf.uint8)
output_data = np.squeeze(output_data)
output_data = tf.clip_by_value(output_data, 0, 255)
output_data = tf.round(output_data)
output_data = tf.cast(output_data, tf.uint8)
plot_sample(image, output_data)
运行结果,
嗯,达到预期结果,说明我们的模型转换是成功了,下一讲我们就可以开始进行安卓移植工作了。
https://mianbaoduo.com/o/bread/YZWVlZ5v