Keras是方便神经网络构建的一个高层API库,其本身并没有实现任何的矩阵运算以及GPU加速运算的功能,它只是为深度学习框架提供一个更方便但也更加生硬的接口,底层的运算都是由深度学习框架完成的(支持TensorFlow、Theano、MXNet)。
由于Keras的实用性,Google随后收购了Keras并将其功能全部集成到了TensorFlow中,TensorFlow2同样支持这些Keras的内容,均在tf.keras模块下(Keras是一个第三方库,使用pip安装,支持但不限于TensorFlow这类多种后端,而tf.keras是集成Keras功能的TensorFlow子模块,建议使用tf.keras而不是Keras这个第三方库)。
使用keras模块主要使用其五个主要模块,为datasets、layers、losses、metrics以及optimizers。
纵观神经网络的训练流程,大致流程如下。
显然,这整个过程可以理解为一个端到端的过程,输入张量,输出预测结果。至于中间的张量流动,是个可以进一步封装的整个模型,Keras实现了这个工作,在Keras中,处理这一切运算的模型是由Model类衍生的。
这个model实例主要实现四个方法,compile,fit, evaluate, predict,通过这四个封装的方法完成了很大代码量的神经网络构建和训练过程。
深度学习的网络模型结构五花八门,大体上就有卷积神经网络(CNN)、循环神经网络(RNN)、图卷积神经网络(GCN)等等。要想利用Keras这样并不灵活的顶层API实现具体的网络模型,预定义的网络层注定是不够的,自定义网络层的API就显得非常重要了。
在Keras中,要想实现自定义网络模型,需要用到三个至关重要的类,keras.Sequential,keras.layers.Layer,keras.Model。其中,Sequential是一个容器,它支持以列表的形式传入多个继承自keras.layers.Layer的实例对象,张量会从输入流经列表中的每一个Layer运算,所以本质上,Sequential表示的是张量的流动,Layer封装的是张量的运算过程。
最终,张量数据的入口是Sequential实例(其实,Sequential类是keras.Model的子类)或者keras.Model的实例,具体的运算流程由keras.layers.Layer完成,keras.Model封装了compile、fit等方法。**所以,自定义的层级运算继承自keras.layers.Layer,自定义的模型继承自keras.Model。**下面演示自定义Layer和Model,注意无论哪种方法,直接使用对象实例处理张量(如layer(x)
或者model(x)
),相当于调用call方法,经过实例内部的所有张量运算。
一方面,作为深度模型最终是需要落地的,因此保存训练好的参数或者干脆保存训练好的整个模型是非常重要的;另一方面,深度模型非常庞大,训练过程有时候很长(几周甚至几个月),很容易出现断电等情况,定时保存模型已经训练的参数是非常重要的。
Keras实现了很方便的模型保存接口,一般有两种方法save/load weights和save/load entire model,前者表示只保存模型参数,占用内存少,需要重新构建保存参数时的Model对象;后者表示保存整个模型对象,占用内存大,不需要重新构建相同的Model对象,load得到的就是Model对象。Keras保存的本地文件使用HDF5文件实现的。
但是,在工业界部署时,一般不使用效率较低的Python环境而是使用C++环境,此时保存Python对象就显得没有意义,而要保存一种通用协议的文件,TensorFlow2也提供了这种接口(所以说TensorFlow在工业界的地位无可替代)。事实上,只需要给出保存目录即可,保存的是一系列相关文件。
通过对CIFAR10数据集图片进行训练并预测类别。
"""
Author: Zhou Chen
Date: 2019/10/17
Desc: About
"""
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
from tensorflow import keras
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
def preprocess(x, y):
# [0~255] => [-1~1]
x = 2 * tf.cast(x, dtype=tf.float32) / 255. - 1.
y = tf.cast(y, dtype=tf.int32)
return x, y
batch_size = 128
# [50k, 32, 32, 3], [10k, 1]
(x, y), (x_val, y_val) = datasets.cifar10.load_data()
y = tf.squeeze(y)
y_val = tf.squeeze(y_val)
y = tf.one_hot(y, depth=10) # [50k, 10]
y_val = tf.one_hot(y_val, depth=10) # [10k, 10]
print('datasets:', x.shape, y.shape, x_val.shape, y_val.shape, x.min(), x.max())
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)
test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = test_db.map(preprocess).batch(batch_size)
class MyDense(layers.Layer):
def __init__(self, inp_dim, outp_dim):
super(MyDense, self).__init__()
self.kernel = self.add_variable('w', [inp_dim, outp_dim])
def call(self, inputs, training=None):
x = inputs @ self.kernel
return x
class MyNetwork(keras.Model):
def __init__(self):
super(MyNetwork, self).__init__()
self.fc1 = MyDense(32 * 32 * 3, 256)
self.fc2 = MyDense(256, 128)
self.fc3 = MyDense(128, 64)
self.fc4 = MyDense(64, 32)
self.fc5 = MyDense(32, 10)
def call(self, inputs, training=None):
"""
:param inputs: [b, 32, 32, 3]
:param training:
:return:
"""
x = tf.reshape(inputs, [-1, 32 * 32 * 3])
# [b, 32*32*3] => [b, 256]
x = self.fc1(x)
x = tf.nn.relu(x)
# [b, 256] => [b, 128]
x = self.fc2(x)
x = tf.nn.relu(x)
# [b, 128] => [b, 64]
x = self.fc3(x)
x = tf.nn.relu(x)
# [b, 64] => [b, 32]
x = self.fc4(x)
x = tf.nn.relu(x)
# [b, 32] => [b, 10]
x = self.fc5(x)
return x
network = MyNetwork()
network.compile(optimizer=optimizers.Adam(lr=1e-3),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
network.fit(train_db, epochs=15, validation_data=test_db, validation_freq=1)
network.evaluate(test_db)
network.save_weights('ckpt/weights.h5')
del network
print('saved to ckpt/weights.h5')
network = MyNetwork()
network.compile(optimizer=optimizers.Adam(lr=1e-3),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
network.load_weights('ckpt/weights.h5')
print('loaded weights from file')
network.evaluate(test_db)
本文主要针对TensorFlow2中Keras这个高层API进行了简单使用上的介绍,其实,Keras是一个很值得学习的大型框架,感兴趣可以查看我的Keras专栏。具体代码上传至我的Github,欢迎star或者fork。