经过前两篇文章的开发,咱们今天终于要进入令人激动的上线篇了。(最近刚刚发布的TensorFlow lite其实也是部署上线的工具集之一)话说我在学习TensorFlow的时候,发现这部分的教程是尤其少。大部分教程都是先上来教一个回归,再来一个CNN,在来几篇保存模型和TensorBoard就结束了。我们这篇文章就来重点聊一聊部署上线。
这篇文章会被分成四个部分,第一部分继续上篇文章,聊一聊第四步调参;第二部分聊一聊训练中的模型保存和载入;第三部分,介绍TensorFlow Serving;第四部分 就是最重要的部署上线流程。
Part 1 训练中的调参
还是先回顾下我们再三提及的解题思路:
第一步:将问题分解成输入(x)到输出(y)这样的结构,如Discuz验证码的输入是图片,输出是四个字符的字符串
第二步:找到很多同时包含输入输出的数据,比如很多有识别结果的验证码图片
第三步:针对不同问题,找到算法大神们的已经定义好的算法并实现成代码
第四步:尝试使用这个算法训练这些数据,如果效果不好,算法中有一些参数可以手动调整,至于怎么调,可以参考前人经验,也可以自己瞎调积累经验。
第五步:写一个程序载入模型,接受一个新的输入值,通过模型计算出新的输出值。
前两篇文章已经走完了前三步,那么我们现在来看的第四步。不负责任的讲,其实机器学习工程师大部分时间都是在调参。毕竟,大神都想好的算法就几个,但是却留了很多参数来调整。我们总还是得体现出自己的价值是不是,那我们来看看一般都有哪些参数值得调整(大家最好对照代码来看,这个是熟悉代码最好的方式):
1.图片最终压缩的长宽
image_resized = tf.image.resize_images(image_gray, [48, 48],tf.image.ResizeMethod.NEAREST_NEIGHBOR)
这一段是我们前面代码中压缩图片的逻辑,那么最终图片压缩到多大是最合适的呢?太大的训练的慢,内存占用高;太小了又会丢失重要信息。一般来说,我们挑一个样本图片,压缩完之后用肉眼看一下,如果你自己肉眼还能识别的话,那就是ok的。
2.神经网络层数和节点数
看了两节课,大家应该大体能知道我们的神经网络是由一层一层的节点组成的,比如这就是一层,这一层实现的是卷积层:
image_x = tf.reshape(image_resized_float,shape=[-1,48,48,1])
conv1 = tf.layers.conv2d(image_x, filters=32, kernel_size=[5, 5], padding='same')
norm1 = tf.layers.batch_normalization(conv1)
activation1 = tf.nn.relu(conv1)
pool1 = tf.layers.max_pooling2d(activation1, pool_size=[2, 2], strides=2, padding='same')
hidden1 = pool1
这样的也是一层,这一层是叫做全连接层:
hidden3 = tf.layers.dense(flatten, units=1024, activation=tf.nn.relu)
大家可以看到我们这些层里面有非常多的参数,比如filters,kernel_size,pool_size,strides,units等等,这些参数到底应该写啥呢。幸运的是,一般来说大神都给了我们一些常规建议,比如kernel_size都是[3,3]或者[5,5],比如pool_size都[2,2],max_pooling2d的strides都是2等等,这些如果咱自己不是大神就别瞎调了。
我们可以发挥一些主观能动性的是filters和units。当然这里我这里也是一些常规值,大家可以在这些常规值上做一些2的倍数的调整(好像也没有什么特别的原因,咱姑且认为是程序员的强迫症吧),当然也不是完全瞎调,首先数量不能太大,因为太大之后,一方面最后模型极大,另外也可能会造成过拟合严重,训练也慢。当然也不能太小,不然可能就拟合不上,准确率上不去。具体大家调整中自行感受,比如在调整filters的时候可以猜想下比如自己要来选择特征会选择几个(所谓特征选择就是有多少个转角,有多少个弧度这种,每一个特点算一个特征)。
另外还有一个参数就是到底要写几层,这个常规来说大神也是定义好的,而且一般把层数都写到算法名字上去了,没啥可发挥的空间。不过强行发挥的话就是如果你觉得你要识别的这个东西比大神举的例子简单很多,就可以少来几层(比如我们现在实现的这个简化版CNN,原版好像是6层,被我删成了4层)。
3.其他的一些(超)参数
还有一些比较常规的参数也可以调整,比如最常见的学习速率(代码里的1e-4):
tf.train.AdamOptimizer(1e-4)
这个不建议大家初学的时候调整,除非发现准确率一直忽高忽低的情况,可以考虑调小一些。
调参->训练->调参->训练,经过几个回合,理论上应该就可以获得一个训练速度符合要求的训练代码了,大家就可以直接启动训练了。
Part 2 模型的保存和载入
事实上,一般的机器学习最终输出的东西都是一个模型。所谓模型,就是一堆变量的值(比如刚刚那个神经网络层里面每个节点中的变量),那么我们如何保存我们训练出来的模型呢?
TensorFlow提供了非常方便的实现,这个应该写过代码的人看一眼就懂了:
saver = tf.train.Saver()
saver.save(sess,'my-model')
注意,我们是通过传入session来保存模型的。(神箭手线上会自动对模型进行起名,上传,因此只需要传入session一个参数即可)
对应的,载入模型也异常简单:
saver.restore(sess,'my-model')
这句的意思就是把my-model这个模型中的所有变量赋值给session中的变量,咱就可以继续跑了。(同样,神箭手自动对模型进行处理,只需要传入session即可)
当然,TensorFlow还提供了另外一种保存和载入模型的机制,这个就是我们第三部分要讲到的主要用于部署上线的模型保存。
Part 3 TensorFlow Serving环境
事实上,TensorFlow对模型的部署上线提供了完善的工具集,其中本篇文章重点要讲到的是TensorFlow Serving,当然类似TensorFlow Lite是另一个方面的重要工具,不过今天这里我们先不做讨论。
那么什么是TensorFlow Serving呢?实际上我们使用数据训练出了最终模型,是要对未知数据进行预测的,那么我们怎么预测呢?当然我们可以把我们的python代码直接执行一遍就行了,不过通常这不符合生产环境的要求(比如速度不够,没有版本控制等等),可敬可爱的TensorFlow工程师急大家之所急,为大家写了一套C++的包含版本控制等重要功能的运行环境:TensorFlow Serving(当然他们还写了一套手机上的运行环境)。那么大家只需要把训练好的模型保存出来,放到这个环境中去执行即可,那么这里就有几个问题出现了:
1.我们得显式告诉这个环境我们输入和输出是什么(就是在我们定义好的图上调出输入节点和输出节点分别是什么)
2.我们在整个预测的这个流程中不能出现纯粹python的代码(比如TensorFlow中的py_func这段代码就不能出现在Serving环境中)
3.我们要注意训练环境和Serving环境的差异性
我们一条一条展开讲:
1.我们得显式告诉这个环境我们输入和输出是什么
也就是说我们需要在保存模型的时候,同时申明我们的输入输出节点分别是什么,这样环境自动会把我们输入的值赋值给指定的节点,然后输出我们指定的输出节点的值就结束了,我们看下代码:
#指定输入输出值
prediction_signature = def_utils.predict_signature_def({'image_base64': x},{'label': predict_join})
可以看到,我们指定的输入值是图中的x,给了一个名字叫image_base64,输出值是图中的predict_join,给的名字是label。如果大家仔细看代码会发现这个predict_join在上篇文章的代码中没有,没关系,一会会给大家完整代码。这个predict_join的计算相对复杂,我们就不在这里展开讲了,目的就是基于模型输出的向量结果,转换成最终我们需要的验证码识别结果的字符串。
接着我们将我们的模型和这个输入输出申明一起保存起来:
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
prediction_signature,
},
legacy_init_op=legacy_init_op)
builder.save()
注意,同样的这里的export_dir是保存路径,神箭手上可以不传入。
2.我们在整个预测的这个流程中不能出现纯粹python的代码
还记得我前两篇文章一直在说的,为什么我们不适用python的base64,python的Image库,而一定要使用TensorFlow的库呢?是因为在TensorFlow Serving的环境中是不能执行python代码的,输入节点和输出节点都得是预定义的图上的节点,因此无论是直接的python代码还是py_func导入的python代码都是无法执行了,那么这时候我们有两个选择:第一个是图像处理,base64等用python的库,而在TensorFlow Serving之前在写一个python的client去调用他,这样也可以,但是显然更复杂;另一个选择就是我们把能写进预定义图的逻辑都写进预定义图中,这样我们就可以去掉python的client直接调用了。
3.我们要注意训练环境和Serving环境的差异性
虽然都是TensorFlow提供的环境,但是由于训练环境是python实现的(虽然底层还是c++,但是毕竟还是有不同的部分),而TensorFlow Serving环境是c++实现的,因此难免部分函数的实现有差异。在实际使用中,我们发现,在Serving中tf.decode_base64的返回值的形式和训练环境有差异,因此我们需要在调用完decode_base64后再调用如下一句来统一训练和Serving环境:
image_bin_reshape = tf.reshape(image_bin,shape=[-1,])
当然,整个代码中肯定还是会有其他差异性,大家在遇到的时候不要觉得奇怪,毕竟谷歌工程师也是人,也有疏漏的地方。
训练+模型保存的完整代码已经上传了Github:
https://github.com/ShenJianShou/tensorflow_tutorial/blob/master/lession-1/python/tensorflow.py
Part 4 模型部署上线
好了,TensorFlow Serving环境的特殊性说完了,需要保存的用于Serving的模型我们也保存好了,下面该正式上线了。这里我们依然介绍两种方案,分别是神箭手方案和线下方案:
1.神箭手方案
说真的,整个TensorFlow Serving的部署虽然不复杂,但是依然有些繁琐,而且灵活性不佳。更要命的就是调用时候,由于需要将输入输出的变量打包成protobuff的格式,因此无论如何也要单独写client来处理就显得更加麻烦。
我们的工程师急大家之所急,急TF工程师之不急,为大家提供了完全集成的Serving环境,大家在运行完上面的模型保存之后,保存的模型会自动进入模型管理中,大家只需要点击右侧的启动Serving按钮即可:
神箭手会自动将模型传输给TensorFlow Serving环境,并且自动解析输入输出变量。然后自动生成一个http的api接口供大家调用:
2.线下方案
如果不在神箭手上训练模型,只需要自行搭建TensorFlow Serving的环境即可,环境的搭建这里不多赘述了,大家只需要按照官方教程一步一步走即可:
https://www.tensorflow.org/serving/
特别要提的依然是这个client的问题,很多人可能被这个弄的很晕。这里简单给大家普及几个概念。
首先TensorFlow Serving使用的是GRPC+ProtoBuff方案,这个方案和我们传统使用的HTTP+JSON(或XML)是不兼容的。之所以使用这个方案,有大概3个考虑:1 这个方案解析快;2 这个方案传输存储消耗都小;3 谷歌顺便推广自家其他的框架。(事实上在谷歌内部几乎所有的存储都是采用protobuff的结构的,所以在TensorFlow上这样使用也不奇怪了)因此大家在自己调用的时候,需要通过先将变量打包成protobuff的结构,在通过编译后的grpc接口去调用。具体实现可以参考官方提供的client接口:
https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_client.py
虽然是官方的例子是python的,我们依然可以用其他语言如JAVA或者C++来实现client,像是神箭手帮大家打包的接口就是通过PHP来实现的Client。不过如果使用非python语言,都需要下载TensorFlow里的proto文件,并在指定的语言上编译好才可以。
好了,无论使用什么方案。当我们训练完之后,将模型导入Serving环境中,我们就完成这样一个Discuz验证码识别接口的编写和实现,从此以后再也不需要调用打码平台啦。
=======================================================
最后再说下:本系列教程的目的是帮助大家入门TensorFlow编程以及了解在机器学习应用在实际问题中时的具体实践。希望大家能够举一反三,没必要局限在Discuz验证码识别这个问题上,包括我们后面也会继续推进其他验证码的识别的。
好了,这个系列第一篇文章到此结束,因为是第一篇所以相对讲的细一些(一篇分成了三篇来讲),后面我们将以一篇解决一个问题来展开。之前文章聊过的细节就不会在去深入探讨了。