关于慕课第四讲中Fashion的各种优化

代码和上一讲只是相差了保存权重和增加了预测功能
1.权重保存的路径和读取方法

checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path+'.index'):
    print('---------------------load the model------------------')
    model.load_weights(checkpoint_save_path)

2.callback权重填写函数

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                save_best_only=True,
                                                save_weights_only=True)

3.训练时调用

history = model.fit(x_train,y_train,batch_size=32,epochs=10,validation_data=(x_test,y_test),validation_freq=2,callbacks=[cp_callback])

4 预测部分
①网络复现

checkpoint_save_path = "./checkpoint/mnist.ckpt"
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128,activation='relu'),
    tf.keras.layers.Dense(10,activation='softmax')
])
model.load_weights(checkpoint_save_path)

#预测函数

def preidct(img_path):
    img = Image.open(img_path)
    img =img.resize((28,28),Image.ANTIALIAS)
    img_arr = np.array(img.convert('L'))
    import matplotlib.pyplot as plt
#    plt.imshow(img,cmap='gray')
#    plt.show()
    for i in range(28):
        for j in range(28):
            if img_arr[i][j] < 200:
                img_arr[i][j] = 255
            else:
                img_arr[i][j] = 0
#    plt.imshow(img_arr,cmap='gray')
#    plt.show()
    img_arr = img_arr/255.0
    x_predict = img_arr[tf.newaxis,...]
    result = model.predict(x_predict)
    pred = tf.argmax(result,axis=1)
    print('\n')
    class_names=['T-shirt/top','Trouser','Pullover','Dress','Coat',
             'Sandal','Shirt','Sneaker','Bag','Ankle boot']
    pred = pred.numpy()
    print(class_names[pred[0]])

对于前面几张的预测效果
关于慕课第四讲中Fashion的各种优化_第1张图片

你可能感兴趣的:(关于慕课第四讲中Fashion的各种优化)