tensorflow2学习笔记 11keras工程化api

断点续训

  • 读取模型
    load_weights(路径文件名)
    生成ckpt的同时会生成index文件,可通过该文件是否存在判断是否有预训练模型生成
ckpt_path = "./mnist.ckpt"
if(os.path.exists(ckpt_path + ".index")):
    print("--load modle--")
    model.load_weights(ckpt_path)
  • 保存模型
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath = 路径文件名,
    save_weights_only=True/False,     #只保留模型参数
    save_best_only=True/False             #只保留最优模型
)
history = model.fit(x_train,y_train,batch_size=32,epochs=5,validation_data=(x_test,y_test),validation_freq=1,callbacks=[cp_callback])

查看训练参数

  • 提取可训练参数
    model.trainable_variables
  • 设置print输出格式
    np.set_printoptions(threshold=超过多少省略显示),此处若需要完全实现参数应设置为np.inf(表示无限大)

acc和loss可视化

#该可视化只可视化了当前运行的训练部分
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.subplot(1,2,1)
plt.plot(acc,label='acc')
plt.plot(val_acc,label='val_acc')
plt.title('acc&&val_cac')
plt.legend()

plt.subplot(1,2,2)
plt.plot(loss,label='loss')
plt.plot(val_loss,label='val_loss')
plt.title('loss&&val_loss')
plt.legend()
plt.show()

前向传播

在训练完成后,使用网络生成预测结果

#复现模型
#加载参数
#前向传播获取结果
result = model.predict(x_predict)

在使用时x_predict需要在图片的原始维度前增加一个维度匹配batch维度

你可能感兴趣的:(tensorflow2学习笔记 11keras工程化api)