tensorflow2装甲板id识别 3使用模型训练结果进行预测

预测实现

在上一篇文章中实现了装甲板id识别的网络训练并保存为了ckpt文件
https://www.jianshu.com/p/191337a9a819
虽然全连接的网络精度也就那样了,但是还是练习一下用现有的网络进行装甲板id预测

  • 复现网络
#网络搭建
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(500,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
    tf.keras.layers.Dense(128,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
    tf.keras.layers.Dense(50,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
    tf.keras.layers.Dense(8,activation='softmax',kernel_regularizer=tf.keras.regularizers.l2())
])
  • 加载参数
#加载参数
ckpt_path = "./checkpoint/armor_id.ckpt"
if(os.path.exists(ckpt_path + ".index")):
    print("--load modle--")
    model.load_weights(ckpt_path)
else:
    print('----------------------------------------------error')
  • 输入数据处理
#图片读取与处理
img = tf.io.read_file (test_img_path)
img_raw = tf.image.decode_bmp (img)
img_raw = tf.cast(img_raw,dtype=tf.float32)
x_predict = tf.convert_to_tensor(img_raw)
x_predict = tf.reshape(x_predict,[1,-1])
  • 代码整体实现
import tensorflow as tf
import os

if __name__ == '__main__':
    test_img_path = './armor_dataset/8/8_47.bmp'
    
    #网络搭建
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(500,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
        tf.keras.layers.Dense(128,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
        tf.keras.layers.Dense(50,activation='relu',kernel_regularizer=tf.keras.regularizers.l2()),
        tf.keras.layers.Dense(8,activation='softmax',kernel_regularizer=tf.keras.regularizers.l2())
    ])
    #加载参数
    ckpt_path = "./checkpoint/armor_id.ckpt"
    if(os.path.exists(ckpt_path + ".index")):
        print("--load modle--")
        model.load_weights(ckpt_path)
    else:
        print('----------------------------------------------error')
        
    #图片读取与处理
    img = tf.io.read_file (test_img_path)
    img_raw = tf.image.decode_bmp (img)
    img_raw = tf.cast(img_raw,dtype=tf.float32)
    x_predict = tf.convert_to_tensor(img_raw)
    x_predict = tf.reshape(x_predict,[1,-1])
    
    #预测结果
    result = model.predict(x_predict)
    pred = tf.argmax(result,axis=1)   #获取概率最大数值的下标
    pred = pred + 1
    print("预测id为:")
    tf.print(pred)

遇到的坑

  • 实际未读入数据
    现象是每次输出结果随机变化
  • 使用tfrecord解码的数据和使用原始数据解码的数据不一致
    应当检查编码解码过程中的类型转换
    https://www.jianshu.com/p/51659ec687f8

测试结果

测试图片


51.png
5177.png
847.png

还进行了其他数字的测试
测试图片基本都实现了正确的预测

你可能感兴趣的:(tensorflow2装甲板id识别 3使用模型训练结果进行预测)