[Keras-tensorflow-backend]技巧大全

前言

keras(tensorflow-backend)可以看作是tensorflow的高级封装。因此在使用keras时,依然需要注意其背后的Graph于Session。在调用keras.layers.*时,keras会调用tensorflow在一个Graph上建图。通过keras.backend.set_session与keras.backend.get_session,可以设置和获取keras背后使用的sess。因此,keras可以完成转换为tensorflow,但tensorflow不能转换为keras。

本文将详细描述:

  • 如何混合使用tensorflow与keras
  • 迁移学习里的keras.backend.learning_phase()。
  • 如何将网络结构用.txt保存与读取。

涉及到的API包括:

  • keras.backend.get_session
  • keras.backend.set_session
  • keras.backend.set_learning_phase()
  • keras.layers.batchnormalization.call()
  • from keras.models import Model
  • from keras.models import model_from_json

keras与tensorflow的混合使用

keras作为tensorflow的高级封装,从层的角度定义网络,可以非常方便的建立、保存、载入、修改网络模型与权重,但是灵活性就比较差。tensorflow作为GPU运算框架,则具备很强的灵活性,但从Variable定义网络过于复杂。于是,考虑keras与tensorflow的混合使用:keras管理模型,而tensorflow去计算一些自定义的过程,如特征图可视化。

混合使用时,一般先是keras定义好模型,tensorflow进一步在已有的模型上建图。首先保证图(Graph)是一致的,则可以通过tf.Graph.get_tensor_by_name()去获取keras定义模型中的Variable与tensor。或者Conv2D.weights也可以获取Variable的指针。然后就可以进一步建图了。

在通过会话运行图前,应保证会话一致。可以先建立一个sess,再传给keras来保证会话的一致,或者不要新建sess,通过keras.backend.get_session()获取keras自建的sess:

# sess一致方法一
sess = tf.Session()
keras.backend.set_session(sess)
# keras build Model
# tensorflow build graph
sess.run()


# sess一致方法二
# keras build Model
# tensorflow build graph
sess = keras.backend.get_session()
sess.run()

sess的不一致会导致模型明明载入了,run运行时却好像没载入!

keras迁移学习

在使用keras进行迁移学习时,往往需要冻结部分层。此时需要特别注意batch normalization层的设置是否正确。

首先,我们知道,BN层存在训练和测试模式。训练时使用batch mean和batch variance,同时更新moving mean和moving variance;测试时使用moving mean和moving variance。

于是这里存在两个需要控制的事情:1)使用batch mean/var还是moving mean/var。这件事在keras由learning phase控制。2)是否更新moving mean和moving variance。这件事则通过设置trainable来控制。

因此,要正确地用keras完成迁移学习,就需要正确地设置每一个bn层的learning phase和trainable。

如下控制bn层的trainable

bn_1 = keras.layers.BatchNormalization(...)
bn_1.trainable = True # False

如下控制learning phase。注意,需要在定义网络时设置

bn_1 = keras.layers.BatchNormalization(...)
output = bn_1(inputs, training=None) # output = bn_1(inputs, training=False|True)

因此,对同一个bn层,1)要么训练此层:trainable=True,training=None。2)要么冻结此层:trainable=False,training=False。注意,这里training若设置为True,那么在测试的时候会出现问题,使用batch mean和batch variance。

你可能感兴趣的:(笔记,tensorflow,深度学习,人工智能)