TensorFlow2-创建Sequential模型

TensorFlow2-创建Sequential模型_第1张图片

一、引入相关包

# coding: utf-8
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import os

二、设置网络维度

#设置网络层维度
M =  50 #输入数据维度
N = 32 #隐藏层维度
L = 10 #输出分类数量
三、声明Sequential模型
# Sequential模型堆叠
# Y = σ(Wx+B) σ=relu
model = tf.keras.Sequential()

四、添加网络隐含层

model.add(layers.Dense(N, activation='relu',
        kernel_initializer='RandomUniform',
        bias_initializer="RandomNormal",
        name="layer_1"))
# Y = σ(Vy) σ=softmax
# 不配置时使用默认初始化
model.add(layers.Dense(L, activation='softmax',name="layer_2"))

说明:
激活函数可选配置: softmax、elu、softplus、softsign、relu、tanh、sigmoid、hard_sigmoid、linear
kernel_initializer和bias_initializer ,可选参数如下。
#常数:zero、zeros(默认偏置项配置)、Zeros、one、ones、Ones、constant、Constant
#均匀分布:uniform、random_uniform、RandomUniform
#正态分布:normal、random_normal、RandomNormal
#截断的正态分布:truncated_normal、TruncatedNormal
#标准化:identity、Identity
#正交:orthogonal、Orthogonal
#正态化Glorot,即Xavier:glorot_normal、GlorotNormal
#Glorot均匀分布(默认权重配置):glorot_uniform、GlorotUniform
默认配置: bias_initializer=‘zeros’,kernel_initializer=‘glorot_uniform’,

五、配置模型保存位置及回调函数

#配置TensorBoard可视化网络训练图
## windows下,logdir路径不能加./否则报如下错误:
#windows下报错: Cannot stop profiling. No profiler is running.
logdir = 'tensorboardLogs'
if not os.path.exists(logdir):
    os.mkdir(logdir)
output_model_file = os.path.join(logdir, "MyFirstModel.h5")
callbacks = [
# 打开CMD密令窗口,进入工程所在目录;输入:
# tensorboard --logdir "./tensorboardLogs"启动可视化网页
    tf.keras.callbacks.TensorBoard(log_dir=logdir),# 定义TensorBoard对象    
	tf.keras.callbacks.ModelCheckpoint(output_model_file,save_best_only = True),
    tf.keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]

六、生成或读入数据

# 生成训练数据
input_x = np.random.random((500, M))
trin_out = np.random.random((500, L))
#生成评估数据
val_in = np.random.random((200, M))
val_out = np.random.random((200, L))

七、编译模型

# 编译模型
model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001),#学习率             loss=tf.keras.losses.categorical_crossentropy,#
                metrics=['accuracy'])

说明:

优化器optimizers可选参数:SGD,RMSprop,Adagrad,Adadelta,Adam
损失函数losses可选参数:mean_squared_error、mean_absolute_error、mean_absolute_percentage_error、mean_squared_logarithmic_error、squared_hinge、hinge、categorical_crossentropy、binary_crossentropy、kullback_leibler_divergence、poisson、cosine_similarity、logcosh、categorical_hinge

八、模型拟合训练

#模型拟合
model.fit(input_x, trin_out, epochs=5, batch_size=100,
                validation_data=(val_in, val_out),
                callbacks=callbacks)

九、训练过程

Train on 500 samples, validate on 200 samples
Epoch 1/5
2020-07-19 11:50:45.670783: I tensorflow/core/profiler/lib/profiler_session.cc:184] Profiler session started.
100/500 [=====>........................] - ETA: 9s - loss: 11.4654 - accuracy: 0.0600
200/500 [===========>..................] - ETA: 3s - loss: 11.5759 - accuracy: 0.0900
500/500 [==============================] - 3s 6ms/sample - loss: 11.4808 - accuracy: 0.0960 - val_loss: 11.8636 - val_accuracy: 0.1000
Epoch 2/5
100/500 [=====>........................] - ETA: 0s - loss: 11.9319 - accuracy: 0.1100
500/500 [==============================] - 0s 190us/sample - loss: 11.6670 - accuracy: 0.0980 - val_loss: 12.0953 - val_accuracy: 0.1000
Epoch 3/5
100/500 [=====>........................] - ETA: 0s - loss: 11.5618 - accuracy: 0.1000
500/500 [==============================] - 0s 150us/sample - loss: 11.9153 - accuracy: 0.0980 - val_loss: 12.3733 - val_accuracy: 0.1000
Epoch 4/5
100/500 [=====>........................] - ETA: 0s - loss: 11.8692 - accuracy: 0.1600
500/500 [==============================] - 0s 150us/sample - loss: 12.2013 - accuracy: 0.1000 - val_loss: 12.6637 - val_accuracy: 0.1000
Epoch 5/5
100/500 [=====>........................] - ETA: 0s - loss: 12.5327 - accuracy: 0.0800
500/500 [==============================] - 0s 148us/sample - loss: 12.4866 - accuracy: 0.1000 - val_loss: 12.9562 - val_accuracy: 0.1000

十、查看TensorBoard
在项目路径栏输入cmd回车

TensorFlow2-创建Sequential模型_第2张图片
TensorFlow2-创建Sequential模型_第3张图片
TensorFlow2-创建Sequential模型_第4张图片
在浏览器中输入http://localhost:6006/查看网络流图、训练信息和硬件使用情况
TensorFlow2-创建Sequential模型_第5张图片
TensorFlow2-创建Sequential模型_第6张图片
TensorFlow2-创建Sequential模型_第7张图片
TensorFlow2-创建Sequential模型_第8张图片

你可能感兴趣的:(TensorFlow测试用例,tensorflow,深度学习)