tensorflow2.6训练一个简单模型并保存和导出进行推理

设置回调函数,只保存权重

一、train.py

import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


from tensorflow.keras import (datasets,layers,models,callbacks)
import matplotlib.pyplot as plt



if __name__ == '__main__':
    # Train a small CNN on MNIST and checkpoint its weights after each epoch
    # in which the monitored validation metric improves.
    (train_images, train_labels), (test_images, test_labels) = \
        datasets.mnist.load_data()

    # Scale pixels by 1/255 and add the trailing channel axis that Conv2D
    # expects. Using -1 lets NumPy infer the sample count instead of
    # hard-coding 60000 / 10000.
    train_images = train_images.reshape(-1, 28, 28, 1) / 255.0
    test_images = test_images.reshape(-1, 28, 28, 1) / 255.0

    # Preview the first 16 training digits with their labels.
    for i in range(16):
        plt.subplot(4, 4, i + 1)
        plt.xticks([])
        plt.yticks([])
        # Plot the 2-D slice: older Matplotlib releases reject (28, 28, 1)
        # arrays passed to imshow.
        plt.imshow(train_images[i, :, :, 0], cmap='gray')
        plt.xlabel(train_labels[i])
    plt.show()

    # Three conv stages followed by a small dense classifier over 10 digits.
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu',
                      input_shape=(28, 28, 1)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ])

    model.summary()

    # Labels are plain integers (not one-hot), hence the sparse loss.
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # The file name is filled in from the epoch number and the logged metrics.
    checkpointPath = "logs/ep{epoch:03d}_acc{accuracy:.2f}-vac{val_accuracy:.2f}.h5"
    # Make sure the target directory exists so writing the .h5 file cannot
    # fail on a fresh checkout.
    os.makedirs(os.path.dirname(checkpointPath), exist_ok=True)
    callback = callbacks.ModelCheckpoint(
        filepath=checkpointPath,
        save_best_only=True,     # only write when the monitored metric improves
        save_weights_only=True,  # weights only; the architecture is not stored
        verbose=1,
    )

    history = model.fit(
        train_images, train_labels,
        epochs=5,
        validation_data=(test_images, test_labels),
        callbacks=[callback],
    )
    

训练过程中会在 logs/ 文件夹下保存 .h5 格式的权重文件(由于设置了 save_best_only=True,只有验证集损失 val_loss 有改善的 epoch 才会被保存)。

tensorflow2.6训练一个简单模型并保存和导出进行推理_第1张图片 

二、predict.py

import warnings
warnings.filterwarnings("ignore")
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

from tensorflow.keras import (models,layers,datasets)
import matplotlib.pyplot as plt
import numpy as np

if __name__ == '__main__':
    # Rebuild the network. Only weights were saved, so the architecture must
    # be recreated here with the same layers, in the same order, as the model
    # that produced the checkpoint.
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu',
                      input_shape=(28, 28, 1)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ])

    # Load the saved weights. The file name encodes epoch and accuracies, so
    # it has to match a checkpoint actually present in logs/.
    model.load_weights("logs/ep004_acc0.99-vac0.99.h5")

    # Take the first test sample as the image to classify.
    (train_x, train_y), (test_x, test_y) = datasets.mnist.load_data()
    predictImg = test_x[0]
    predictLabel = test_y[0]

    # Show the image together with its ground-truth label.
    plt.imshow(predictImg, cmap='gray')
    plt.xlabel(predictLabel)
    plt.xticks([])
    plt.yticks([])
    plt.show()

    # Convert to the network's input format: a batch of one 28x28x1 image,
    # divided by 255.0 so the pixel scaling matches what the training script
    # fed the model. Feeding raw 0-255 pixels would not match the training
    # distribution.
    predictImg = predictImg.reshape(1, 28, 28, 1) / 255.0

    # Run inference; `out` holds one row of 10 softmax scores.
    out = model.predict(predictImg)

    # The predicted digit is the index of the highest score.
    preLabel = np.argmax(out[0])
    print(preLabel)

 tensorflow2.6训练一个简单模型并保存和导出进行推理_第2张图片

tensorflow2.6训练一个简单模型并保存和导出进行推理_第3张图片

 

 

你可能感兴趣的:(tensorflow2,tensorflow,人工智能,python)