详细API
导入标准库
import tensorflow as tf
# 导入 keras 模型,不能使用 import keras,它导入的是标准的 Keras 库
from tensorflow import keras
from tensorflow.keras import layers # 导入常见网络层类
网络容器 Sequential 将多个网络层封装成一个大网络模型,只需要调用网络模型的实例一次即可完成数据从第一层到最末层的顺序运算
# 导入 Sequential 容器
from tensorflow.keras import layers, Sequential
network = Sequential([ # 封装为一个网络
layers.Dense(3, activation=None), # 全连接层
layers.ReLU(),#激活函数层
layers.Dense(2, activation=None), # 全连接层
layers.ReLU() #激活函数层 ])
x = tf.random.normal([4,3])
network(x) # 输入从第一层开始,逐层传播至最末层
Sequential 容器也可以通过 add()方法继续追加新的网络层,实现动态创建网络的功能
当需要很多层时这样建网络:
layers_num = 2 # 堆叠 2 次
network = Sequential([]) # 先创建空的网络
for _ in range(layers_num):
network.add(layers.Dense(3)) # 添加全连接层
network.add(layers.ReLU())# 添加激活函数层
network.build(input_shape=(None, 4)) # 创建网络参数
network.summary() #打印出网络结构 和参数量
通过 compile 函数指定网络使用的优化器对象,损失函数,评价指标等:
# 导入优化器,损失函数模块
from tensorflow.keras import optimizers,losses
# 采用 Adam 优化器,学习率为 0.01;采用交叉熵损失函数,包含
Softmax network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'] )# 设置测量指标为准确率
通过 fit()函数送入待训练的数据和验证用的数据集:
# 训练集为 train_db,验证集为 val_db,训练 5 个 epochs,每 2 个 epoch 验证一次
# 返回训练信息保存在 history 中
history = network.fit(train_db, epochs=5, validation_data=val_db,
validation_freq=2)
history.history # 打印训练记录
# 加载一个 batch 的测试数据
x,y = next(iter(db_test))
print('predict x:', x.shape) out = network.predict(x) # 模型预测
print(out)
#简单测试模型
network.evaluate(db_test) # 模型测试
有三种方法:
最轻量级的一种方式,直接保存网络张量参数到文件上
# 保存模型参数到文件上
network.save_weights('weights.ckpt')
del network # 删除网络对象
# 重新创建相同的网络结构
network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'] )
# 从参数文件中读取数据并写入当前网络
network.load_weights('weights.ckpt')
文件中保存的仅仅是参数张量的数值,并没有其 他额外的结构参数。但是它需要使用相同的网络结构才能够恢复网络状态,因此一般在拥 有网络源文件的情况下使用。
仅仅需要模型参数文件即可恢复出网络模型的方式,不需要网络源文件的条件下,通过 keras.models.load_model(path)即可恢复网络结构和网 络参数
# 保存模型结构与模型参数到文件
network.save('model.h5') #注意保存使用的函数不一样
print('saved total model.')
del network # 删除网络对象
#此时通过 model.h5 文件即可恢复出网络的结构和状态:
# 从文件恢复网络结构与网络参数
network = tf.keras.models.load_model('model.h5')
# 保存模型结构与模型参数到文件
tf.keras.experimental.export_saved_model(network, 'model-savedmodel')
print('export saved model.')
del network # 删除网络对象
# 从文件恢复网络结构与网络参数
network = tf.keras.experimental.load_from_saved_model('model-savedmodel')
方便各个平台能够无缝对接训练好的网络模型。