一、引入相关包
# coding: utf-8
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import os
二、设置网络维度
#设置网络层维度
M = 50 #输入数据维度
N = 32 #隐藏层维度
L = 10 #输出分类数量
三、声明Sequential模型
# Sequential模型堆叠
# Y = σ(Wx+B)σ=relu
model = tf.keras.Sequential()
四、添加网络隐含层
model.add(layers.Dense(N, activation='relu',kernel_initializer='RandomUniform',
bias_initializer="RandomNormal",name="layer_1"))
# Y = σ(Vy)σ=softmax
# 不配置时使用默认初始化
model.add(layers.Dense(L, activation='softmax',name="layer_2"))
说明:
激活函数可选配置: softmax、elu、softplus、softsign、relu、tanh、sigmoid、hard_sigmoid、linear
kernel_initializer和bias_initializer,可选参数如下。
#常数:zero、zeros(默认偏置项配置)、Zeros、one、ones、Ones、constant、Constant
#均匀分布:uniform、random_uniform、RandomUniform
#正态分布:normal、random_normal、RandomNormal
#截断的正态分布:truncated_normal、TruncatedNormal
#标准化:identity、Identity
#正交:orthogonal、Orthogonal
#正态化Glorot,即Xavier:glorot_normal、GlorotNormal
#Glorot均匀分布(默认权重配置):glorot_uniform、GlorotUniform
默认配置:bias_initializer='zeros',kernel_initializer='glorot_uniform',
五、配置模型保存位置及回调函数
#配置TensorBoard可视化网络训练图
## windows下,logdir路径不能加./否则报如下错误:
#windows下报错: Cannot stop profiling. No profiler is running.logdir = 'tensorboardLogs'
if not os.path.exists(logdir):
os.mkdir(logdir)
output_model_file = os.path.join(logdir, "MyFirstModel.h5")
callbacks = [#打开CMD密令窗口,进入工程所在目录;输入:
# tensorboard --logdir "./tensorboardLogs"启动可视化网页
tf.keras.callbacks.TensorBoard(log_dir=logdir),#定义TensorBoard对象
tf.keras.callbacks.ModelCheckpoint(output_model_file,save_best_only = True),
tf.keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3),]
六、生成或读入数据
# 生成训练数据
input_x = np.random.random((500, M))
trin_out = np.random.random((500, L))
#生成评估数据
val_in = np.random.random((200, M))
val_out = np.random.random((200, L))
七、编译模型
# 编译模型
model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001),#学习率
loss=tf.keras.losses.categorical_crossentropy,metrics=['accuracy'])
说明:
优化器optimizers可选参数:SGD,RMSprop,Adagrad,Adadelta,Adam
损失函数losses可选参数:mean_squared_error、mean_absolute_error、mean_absolute_percentage_error、mean_squared_logarithmic_error、squared_hinge、hinge、categorical_crossentropy、binary_crossentropy、kullback_leibler_divergence、poisson、cosine_similarity、logcosh、categorical_hinge
八、模型拟合训练
#模型拟合
model.fit(input_x, trin_out, epochs=5, batch_size=100,
validation_data=(val_in, val_out),callbacks=callbacks)
九、训练过程
Train on 500 samples, validate on 200samples
Epoch 1/5
2020-07-19 11:50:45.670783: Itensorflow/core/profiler/lib/profiler_session.cc:184] Profiler session started.
100/500[=====>........................] - ETA: 9s - loss: 11.4654 - accuracy:0.0600
200/500[===========>..................] - ETA: 3s - loss: 11.5759 - accuracy:0.0900
500/500 [==============================]- 3s 6ms/sample - loss: 11.4808 - accuracy: 0.0960 - val_loss: 11.8636 -val_accuracy: 0.1000
Epoch 2/5
100/500[=====>........................] - ETA: 0s - loss: 11.9319 - accuracy:0.1100
500/500 [==============================]- 0s 190us/sample - loss: 11.6670 - accuracy: 0.0980 - val_loss: 12.0953 -val_accuracy: 0.1000
Epoch 3/5
100/500[=====>........................] - ETA: 0s - loss: 11.5618 - accuracy:0.1000
500/500 [==============================]- 0s 150us/sample - loss: 11.9153 - accuracy: 0.0980 - val_loss: 12.3733 -val_accuracy: 0.1000
Epoch 4/5
100/500[=====>........................] - ETA: 0s - loss: 11.8692 - accuracy:0.1600
500/500 [==============================]- 0s 150us/sample - loss: 12.2013 - accuracy: 0.1000 - val_loss: 12.6637 -val_accuracy: 0.1000
Epoch 5/5
100/500[=====>........................] - ETA: 0s - loss: 12.5327 - accuracy:0.0800
500/500 [==============================]- 0s 148us/sample - loss: 12.4866 - accuracy: 0.1000 - val_loss: 12.9562 -val_accuracy: 0.1000
十、查看TensorBoard
在项目路径栏输入cmd回车
在浏览器中输入http://localhost:6006/查看网络流图、训练信息和硬件使用情况
TensorFlow2第一篇测试用例到此结束~~~~~~~~~~