本讲代码与上一讲相比，区别仅在于增加了权重保存和预测功能。
1.权重保存的路径和读取方法
# Prefix of the TensorFlow-format checkpoint; the weights are stored as
# several files sharing this prefix, including "<prefix>.index".
checkpoint_save_path = "./checkpoint/mnist.ckpt"

# Resume from earlier weights when a checkpoint already exists: the ".index"
# file is used as the marker that one has been written before.
if os.path.exists(f"{checkpoint_save_path}.index"):
    print('---------------------load the model------------------')
    model.load_weights(checkpoint_save_path)
2. 用于保存权重的 callback 函数（ModelCheckpoint）
# Persist only the weights (not the whole model) to the checkpoint prefix,
# and only when the monitored metric improves. monitor='val_loss' is the
# Keras default, written out because it interacts with validation_freq=2 in
# model.fit: on epochs where validation is skipped, val_loss is absent, and
# Keras skips the save for that epoch with a warning.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    monitor='val_loss',
    save_weights_only=True,
    save_best_only=True,
)
3.训练时调用
# Train for 10 epochs, evaluating on the test split every 2nd epoch; the
# checkpoint callback keeps the best weights seen so far.
history = model.fit(
    x_train,
    y_train,
    batch_size=32,
    epochs=10,
    validation_data=(x_test, y_test),
    validation_freq=2,
    callbacks=[cp_callback],
)
4. 预测部分
①网络复现
# Rebuild the network for inference. Its layer stack must match the model
# whose weights were saved (the training model is defined outside this
# excerpt -- TODO confirm it is Flatten -> Dense(128, relu) -> Dense(10, softmax)).
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),                        # flatten each input sample
    tf.keras.layers.Dense(128, activation='relu'),    # hidden layer
    tf.keras.layers.Dense(10, activation='softmax'),  # 10 class probabilities
])

# Restore the trained weights from the same checkpoint prefix used in training.
checkpoint_save_path = "./checkpoint/mnist.ckpt"
model.load_weights(checkpoint_save_path)
# Prediction function
def preidct(img_path, *, class_names=None):
    """Classify one image file with the module-level ``model`` and print the label.

    Preprocessing: resize to 28x28, convert to grayscale, binarise and invert
    (pixels darker than 200 become 255, all others 0), scale to [0, 1], and
    add a leading batch axis before calling ``model.predict``.

    Args:
        img_path: Path (or file object) accepted by ``PIL.Image.open``.
        class_names: Optional sequence mapping class index -> label. Defaults
            to the Fashion-MNIST label list this function originally used.

    Returns:
        The predicted label, which is also printed.

    Note:
        The misspelled name ``preidct`` is kept so existing callers still work.
    """
    if class_names is None:
        # NOTE(review): these are Fashion-MNIST labels, but the weights come
        # from "mnist.ckpt" and the binarisation below looks intended for
        # handwritten digits -- confirm which dataset the model was trained on.
        class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                       'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

    # The context manager closes the underlying file handle. Image.LANCZOS is
    # the same filter as the old Image.ANTIALIAS alias, which Pillow 10 removed.
    with Image.open(img_path) as src:
        resized = src.resize((28, 28), Image.LANCZOS)
    img_arr = np.array(resized.convert('L'))

    # Vectorised replacement for the per-pixel double loop: strokes darker
    # than 200 become foreground (255), everything else background (0).
    img_arr = np.where(img_arr < 200, 255, 0).astype(np.uint8)
    img_arr = img_arr / 255.0
    x_predict = img_arr[np.newaxis, ...]  # batch of one

    result = model.predict(x_predict)
    pred = int(np.argmax(result, axis=1)[0])
    print('\n')
    label = class_names[pred]
    print(label)
    return label