import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, datasets
import numpy as np
import matplotlib.pyplot as plt
import datetime
def prepare_mnist_features_and_labels(x, y):
    x = tf.cast(x, tf.float32) / 255.0
    y = tf.cast(y, tf.int64)
    return x, y
def mnist_dataset():
    (x, y), (x_val, y_val) = datasets.fashion_mnist.load_data()
    x, x_val = np.expand_dims(x, axis=3), np.expand_dims(x_val, axis=3)
    y = tf.one_hot(y, depth=10)
    y_val = tf.one_hot(y_val, depth=10)
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    ds = ds.map(prepare_mnist_features_and_labels)
    ds = ds.shuffle(60000).batch(100)
    ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
    ds_val = ds_val.map(prepare_mnist_features_and_labels)
    ds_val = ds_val.shuffle(10000).batch(100)
    return ds, ds_val
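# Shape check (illustrative, not part of the original pipeline): each training batch
# yields images of shape (100, 28, 28, 1) scaled to [0, 1] and one-hot labels of
# shape (100, 10), e.g.:
#   for images, labels in mnist_dataset()[0].take(1):
#       print(images.shape, labels.shape)  # (100, 28, 28, 1) (100, 10)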
class LearningRateExponentialDecay:
    def __init__(self, initial_learning_rate, decay_epochs, decay_rate):
        self.initial_learning_rate = initial_learning_rate
        self.decay_epochs = decay_epochs
        self.decay_rate = decay_rate

    def __call__(self, epoch):
        dtype = type(self.initial_learning_rate)
        decay_epochs = np.array(self.decay_epochs).astype(dtype)
        decay_rate = np.array(self.decay_rate).astype(dtype)
        epoch = np.array(epoch).astype(dtype)
        p = epoch / decay_epochs
        lr = self.initial_learning_rate * np.power(decay_rate, p)
        return lr
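# Illustrative decay values for the configuration used below
# (initial_learning_rate=0.001, decay_epochs=1, decay_rate=0.96):
#   epoch 0 -> 0.001, epoch 1 -> 0.00096, epoch 10 -> 0.001 * 0.96**10 ≈ 0.000665
# i.e. lr = initial_learning_rate * decay_rate ** (epoch / decay_epochs), evaluated
# continuously in the epoch index rather than in discrete steps.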
def save_log():
    log_dir = os.path.join("C:/Users/byroot/Desktop/test1/model_output",
                           'logs_{}'.format(datetime.datetime.now().strftime("%Y%m%d-%H%M%S")))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir
def callbackfunc():
    # The default model save format is SavedModel; the original class can be
    # customized to save HDF5 instead (here an .hdf5 filepath is used).
    ckpt = tf.keras.callbacks.ModelCheckpoint(filepath=path, monitor='val_loss',
                                              save_best_only=False, save_weights_only=False)
    # Stop training when it no longer meets our requirements: here, if validation
    # accuracy has not improved by at least min_delta for 5 consecutive epochs.
    earlystop = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy', min_delta=0.0001, patience=5)
    lr_schedule = LearningRateExponentialDecay(initial_learning_rate=0.001, decay_epochs=1, decay_rate=0.96)
    lr = tf.keras.callbacks.LearningRateScheduler(lr_schedule)
    # Stop training as soon as the loss becomes NaN or Inf.
    terminate = tf.keras.callbacks.TerminateOnNaN()
    # Reduce the learning rate on a plateau (a larger change, monitored over a longer
    # window, than the automatic periodic schedule above).
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3,
                                                     min_delta=0.0001, min_lr=0)
    # Log most of the scalar metrics during training to a CSV file in the same
    # directory as the TensorBoard logs.
    csv_logger = tf.keras.callbacks.CSVLogger(os.path.join(log_dir, 'logs.log'), separator=',')
    # Also add TensorBoard; this callback records a limited set of information.
    tensorboard = tf.keras.callbacks.TensorBoard(log_dir=log_dir,
                                                 histogram_freq=1,     # histograms of weights and activations; requires validation data
                                                 write_graph=True,     # write the model graph
                                                 write_images=True,    # store model weights as images
                                                 update_freq='epoch',  # 'epoch', 'batch', or an integer; too frequent slows training down
                                                 profile_batch=2,      # profile performance on this batch
                                                 embeddings_freq=1,
                                                 embeddings_metadata=None  # not yet clear how this is used
                                                 )
    callback = [ckpt, earlystop, lr, tensorboard, terminate, reduce_lr, csv_logger]
    return callback
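# To inspect the TensorBoard logs written above, run from a shell (pointing at the
# model_output directory used in save_log; adjust the path to your own setup):
#   tensorboard --logdir "C:/Users/byroot/Desktop/test1/model_output"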
def saveModel():
    # to_json() stores only the architecture; the weights are saved separately by ModelCheckpoint.
    model_json = model.to_json()
    with open(os.path.join(log_dir, 'model_json.json'), 'w') as json_file:
        json_file.write(model_json)
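# Illustrative sketch (not called in this script): restoring a trained model later.
# The .hdf5 checkpoints written by ModelCheckpoint contain the full model, while
# model_json.json holds only the architecture; both can be reloaded with standard
# Keras APIs. The function name below is just for illustration.
def load_saved_model(ckpt_path, json_path=None):
    if json_path is not None:
        # Rebuild the architecture from JSON, then load weights from the checkpoint.
        with open(json_path) as f:
            rebuilt = tf.keras.models.model_from_json(f.read())
        rebuilt.load_weights(ckpt_path)
        return rebuilt
    # A full .hdf5 checkpoint restores architecture, weights and optimizer state.
    return tf.keras.models.load_model(ckpt_path)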
if __name__ == '__main__':
path = "C:/Users/byroot/Desktop/test1/" + "ckpt_epoch{epoch:02d}_val_acc{val_loss:.2f}.hdf5"
"""1.参数配置"""
log_dir = save_log()
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # or any {'0', '1', '2'}
"""2.数据处理"""
train_dataset, test_dataset = mnist_dataset()
"""3.设置模型"""
model = keras.Sequential([
layers.Reshape(target_shape=(28 * 28,), input_shape=(28, 28, 1 )),
layers.Dense(200, activation='relu'),
layers.Dense(200, activation='relu'),
layers.Dense(200, activation='relu'),
layers.Dense(10)])
# no need to use compile if you have no loss/optimizer/metrics involved here.
model.compile(optimizer=optimizers.Adam(0.001),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
"""4.画出模型结构图并保存"""
tf.keras.utils.plot_model(model,to_file=os.path.join('C:/Users/byroot/Desktop/test1/','model.png'),show_shapes=True,show_layer_names=True)
"""5.配置回调函数"""
callback = callbackfunc()
"""6.开始训练"""
model.fit(train_dataset.repeat(), epochs=1, steps_per_epoch=10,
validation_data=test_dataset.repeat(),
validation_steps=2,
callbacks = callback
)
"""7.打印模型结构"""
model.summary()
"""8.保存模型结构及配置参数"""
saveModel()
"""9.对模型在测试集上进行评估"""
metrics = model.evaluate(test_dataset)
print("val_loss:", metrics[0], "val_accuracy:", metrics[1])