设置 ModelCheckpoint 回调函数，训练时只保存模型权重
import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow.keras import (datasets,layers,models,callbacks)
import matplotlib.pyplot as plt
def to_model_input(images):
    """Convert raw MNIST pixels into the tensor layout the CNN is trained on.

    Accepts a single 28x28 image or a batch of them and returns an array of
    shape (batch, 28, 28, 1) with values scaled from [0, 255] to [0, 1].
    Inference code must apply exactly the same transform, otherwise the
    network sees inputs 255x larger than during training.
    """
    return images.reshape(-1, 28, 28, 1) / 255.0


if __name__ == '__main__':
    # Load MNIST and bring both splits into the (N, 28, 28, 1), [0, 1] layout.
    # reshape(-1, ...) inside the helper avoids hard-coding the split sizes.
    (train_images, train_labels), (test_images, test_labels) = \
        datasets.mnist.load_data()
    train_images = to_model_input(train_images)
    test_images = to_model_input(test_images)

    # Preview the first 16 training digits together with their labels.
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.xticks([])
        plt.yticks([])
        # Drop the trailing channel axis: imshow expects (H, W) for a
        # greyscale image, and some matplotlib versions reject (H, W, 1).
        plt.imshow(train_images[i, :, :, 0], cmap='gray')
        plt.xlabel(train_labels[i])
    plt.show()

    # Small CNN: three conv layers (two followed by max-pooling), then a
    # dense classifier with a 10-way softmax, one output per digit.
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu',
                            input_shape=(28, 28, 1)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
    model.summary()

    # Labels are integer class ids, hence the *sparse* categorical loss.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # The file name records the epoch plus train/validation accuracy; the
    # {accuracy}/{val_accuracy} fields are filled from the epoch's logs.
    checkpoint_path = "logs/ep{epoch:03d}_acc{accuracy:.2f}-vac{val_accuracy:.2f}.h5"
    # Create logs/ up front so the first HDF5 save cannot fail on a
    # missing directory.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    callback = callbacks.ModelCheckpoint(
        filepath=checkpoint_path,
        monitor='val_loss',      # Keras' default, made explicit: "best" = lowest val_loss
        save_best_only=True,     # write a file only when val_loss improves
        save_weights_only=True,  # weights only; the loader rebuilds the architecture
        verbose=1,
    )
    history = model.fit(
        train_images, train_labels,
        epochs=5,
        validation_data=(test_images, test_labels),
        callbacks=[callback],
    )
训练时会在 logs/ 文件夹下保存 .h5 格式的权重文件（由于 save_best_only=True，只有 val_loss 改善的轮次才会生成文件）
import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow.keras import (models,layers,datasets)
import matplotlib.pyplot as plt
import numpy as np
def to_model_input(images):
    """Convert raw MNIST pixels into the tensor layout the CNN was trained on.

    Accepts a single 28x28 image or a batch of them and returns an array of
    shape (batch, 28, 28, 1) with values scaled from [0, 255] to [0, 1].
    This must match the training-time preprocessing exactly; feeding raw
    0-255 pixels would present the network with out-of-distribution inputs.
    """
    return images.reshape(-1, 28, 28, 1) / 255.0


def latest_checkpoint(directory="logs", suffix=".h5"):
    """Return the path of the most recently written weight file in *directory*.

    Only regular files whose name ends with *suffix* are considered. Ties on
    modification time are broken by path, so ``ep005_...`` wins over
    ``ep004_...`` when both were written within the same timestamp tick.

    Raises:
        FileNotFoundError: if *directory* does not exist or holds no
            matching file.
    """
    candidates = [
        entry.path
        for entry in os.scandir(directory)
        if entry.is_file() and entry.name.endswith(suffix)
    ]
    if not candidates:
        raise FileNotFoundError(
            f"no '*{suffix}' weight files in {directory!r}; "
            "run the training script first"
        )
    return max(candidates, key=lambda path: (os.path.getmtime(path), path))


if __name__ == '__main__':
    # Rebuild the exact architecture used for training: load_weights maps the
    # stored tensors onto layers, so the structure must match.
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu',
                            input_shape=(28, 28, 1)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))

    # Load the saved weights. Training writes a file only when the monitored
    # metric improves (save_best_only=True), so the newest file is the best
    # one -- no need to hard-code a run-specific name such as
    # "logs/ep004_acc0.99-vac0.99.h5".
    model.load_weights(latest_checkpoint("logs"))

    _, (test_x, test_y) = datasets.mnist.load_data()
    predict_img = test_x[0]
    predict_label = test_y[0]

    # Show the raw test image with its ground-truth label.
    plt.imshow(predict_img, cmap='gray')
    plt.xlabel(predict_label)
    plt.xticks([])
    plt.yticks([])
    plt.show()

    # Same preprocessing as training: scale to [0, 1] and add the batch and
    # channel axes -> (1, 28, 28, 1).
    out = model.predict(to_model_input(predict_img))

    # out holds one row of 10 softmax scores; the argmax is the predicted digit.
    pred_label = int(np.argmax(out[0]))
    print(pred_label)