18.2:tensorflow分类模型mobilenetv2训练(数据增强,保存模型,衰减学习率,tensorboard),预测图像(单张,批量预测),导出为pb完整示例

二、预测

经过一段时间的训练后在model_save文件夹下保存了下面的模型:

(图:model_save 文件夹下保存的模型检查点文件)

他想使用model.ckpt-250000模型对数据进行预测。

他写了个predict.py程序为:

#coding:utf-8
import os, cv2
#os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # use cpu

import numpy as np
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import glob

import model


config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand
# NOTE(review): set after `import tensorflow`; this only takes effect if no
# CUDA context has been created yet (no Session exists at this point).
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # use gpu 0


def _load_label_maps(label_file="label.txt"):
    """Read ``folder:label`` lines from *label_file*.

    The folder is the text before the first ':' and the label is the text
    between the first and second ':'. Blank lines (e.g. an empty last line)
    are skipped.

    Returns:
        (folder -> label dict, label -> folder dict)

    Raises:
        ValueError: a non-blank line contains no ':'.
    """
    folder_to_label, label_to_folder = {}, {}
    with open(label_file, 'r') as f:
        for lineno, raw in enumerate(f, 1):
            fields = raw.strip().split(':')
            if fields == ['']:
                continue
            if len(fields) < 2:
                raise ValueError("%s line %d: expected 'folder:label', got %r"
                                 % (label_file, lineno, raw.strip()))
            folder, label = fields[0], fields[1]
            folder_to_label[folder] = label
            label_to_folder[label] = folder
    return folder_to_label, label_to_folder


label_dict, label_dict_res = _load_label_maps()
print(label_dict)

N_CLASSES = len(label_dict)
IMG_W = 224
IMG_H = IMG_W

def init_tf(logs_train_dir='./model_save/model.ckpt-250000'):
    """Build the single-image inference graph and restore its weights.

    Publishes the module globals used by ``evaluate_image``:
      * ``x``    - float32 placeholder for ONE un-batched IMG_W x IMG_W x 3 image
      * ``pred`` - softmax probabilities over N_CLASSES
      * ``sess`` - session holding the restored checkpoint

    Args:
        logs_train_dir: checkpoint prefix handed to ``tf.train.Saver.restore``.
    """
    global sess, pred, x

    # Standardisation is done inside the graph, then a batch axis is added.
    x = tf.placeholder(tf.float32, shape=[IMG_W, IMG_W, 3])
    batched_input = tf.reshape(tf.image.per_image_standardization(x),
                               [-1, IMG_W, IMG_W, 3])

    network = model.MobileNetV2(batched_input, num_classes=N_CLASSES,
                                is_training=False)
    logits = network.output
    print("logit", np.shape(logits))
    pred = tf.nn.softmax(logits)

    checkpoint_loader = tf.train.Saver()
    sess = tf.Session(config=config)
    checkpoint_loader.restore(sess, logs_train_dir)
    print('load model done...')

def evaluate_image(img_dir):
    """Classify one image file and print its top-1 prediction.

    Args:
        img_dir: path of a single image file (despite the name).

    Files OpenCV cannot decode are reported and skipped rather than
    crashing the whole run inside ``cv2.cvtColor``.
    """
    im = cv2.imread(img_dir)
    if im is None:
        # cv2.imread signals failure by returning None instead of raising.
        print("skip unreadable image: %s" % img_dir)
        return
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB
    image_array = cv2.resize(im, (IMG_W, IMG_H))  # dsize is (width, height)

    # ``x`` takes one un-batched image; standardisation happens in the graph.
    prediction = sess.run(pred, feed_dict={x: image_array})
    max_index = int(np.argmax(prediction))
    pred_label = label_dict_res[str(max_index)]
    print("%s, predict: %s(index:%d), prob: %f" %(img_dir, pred_label, max_index, prediction[0][max_index]))
    

if __name__ == '__main__':
    init_tf()
    data_path = "/media/DATA2/sku_val"
    try:
        # Images are expected in sub-folders of data_path (presumably one per
        # class); loose files at the top level are ignored.
        for class_dir in os.listdir(data_path):
            if os.path.isfile(os.path.join(data_path, class_dir)):
                continue
            for img in glob.glob(os.path.join(data_path, class_dir, "*.jpg")):
                evaluate_image(img_dir=img)
    finally:
        # Release the session even if a prediction raises.
        sess.close()

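他嫌逐张输入太慢,于是修改了代码,让每次输入一个 batch,并将其命名为 predict_batch.py。注意:该脚本把图像标准化从计算图中移到了 NumPy 预处理中(计算方式与 tf.image.per_image_standardization 相同),并且默认加载的是 model.ckpt-140000。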

#coding:utf-8
import os, cv2, time
#os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # use cpu

import numpy as np
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import glob

import model

def _load_label_maps(label_file="label.txt"):
    """Read ``folder:label`` lines from *label_file*.

    The folder is the text before the first ':' and the label is the text
    between the first and second ':'. Blank lines (e.g. an empty last line)
    are skipped.

    Returns:
        (folder -> label dict, label -> folder dict)

    Raises:
        ValueError: a non-blank line contains no ':'.
    """
    folder_to_label, label_to_folder = {}, {}
    with open(label_file, 'r') as f:
        for lineno, raw in enumerate(f, 1):
            fields = raw.strip().split(':')
            if fields == ['']:
                continue
            if len(fields) < 2:
                raise ValueError("%s line %d: expected 'folder:label', got %r"
                                 % (label_file, lineno, raw.strip()))
            folder, label = fields[0], fields[1]
            folder_to_label[folder] = label
            label_to_folder[label] = folder
    return folder_to_label, label_to_folder


label_dict, label_dict_res = _load_label_maps()
print(label_dict)

N_CLASSES = len(label_dict)
IMG_W = 224
IMG_H = IMG_W
batch_size = 16

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # use gpu 0
config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand


def get_imgpath(path):
    """Recursively collect ``.jpg`` files under *path*.

    Files whose parent directory is the current working directory are
    skipped, as are entries that are not regular files (e.g. broken links).

    Returns:
        list of file paths, in ``os.walk`` traversal order.
    """
    cwd = os.getcwd()  # loop-invariant; look it up once
    img_list = []
    for fpath, _dirs, fs in os.walk(path):
        for f in fs:
            # Require the real ".jpg" suffix; a bare [-3:] slice would also
            # accept names such as "notjpg".
            if not f.endswith(".jpg"):
                continue
            img_path = os.path.join(fpath, f)
            if os.path.dirname(img_path) == cwd:
                continue
            if os.path.isfile(img_path):
                img_list.append(img_path)
    return img_list


def init_tf(logs_train_dir='./model_save/model.ckpt-140000'):
    """Build the batched inference graph and restore its weights.

    Publishes the module globals used by ``evaluate_image``:
      * ``x``    - placeholder named "input_1", shape [None, IMG_W, IMG_W, 3];
                   callers feed images that are already standardised
      * ``pred`` - softmax op named "pred"
      * ``sess`` - session holding the restored checkpoint

    Args:
        logs_train_dir: checkpoint prefix handed to ``tf.train.Saver.restore``.
    """
    global sess, pred, x

    # Explicit op names so the tensors can be located by name later.
    x = tf.placeholder(tf.float32, shape=[None, IMG_W, IMG_W, 3],
                       name="input_1")
    network = model.MobileNetV2(x, num_classes=N_CLASSES, is_training=False)
    pred = tf.nn.softmax(network.output, name="pred")

    checkpoint_loader = tf.train.Saver()
    sess = tf.Session(config=config)
    checkpoint_loader.restore(sess, logs_train_dir)
    print('load model done...')

def _load_standardized(img_path):
    """Load one image as RGB, resize it, and standardise it in NumPy.

    Mirrors ``tf.image.per_image_standardization``: subtract the mean and
    divide by the stddev floored at 1/sqrt(number of elements).
    Returns None when OpenCV cannot read the file.
    """
    im = cv2.imread(img_path)
    if im is None:
        return None
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    im = cv2.resize(im, (IMG_W, IMG_H))  # dsize is (width, height)
    stddev = max(np.std(im), 1.0 / np.sqrt(IMG_W * IMG_H * 3))
    return (im - np.mean(im)) / stddev


def evaluate_image(img_dir):
    """Classify a batch of image files and print each top-1 prediction.

    Args:
        img_dir: list of image file paths forming one batch.

    Unreadable files are reported and dropped from the batch; each
    prediction is paired with the path that was actually fed, so results
    never shift onto the wrong file.
    """
    fed_paths, batch_img = [], []
    for img in img_dir:
        arr = _load_standardized(img)
        if arr is None:
            print("skip unreadable image: %s" % img)
            continue
        fed_paths.append(img)
        batch_img.append(arr)
    if not batch_img:
        # An empty feed does not match the [None, H, W, 3] placeholder.
        return
    # output softmax
    prediction = sess.run(pred, feed_dict={x: np.asarray(batch_img, dtype=np.float32)})
    for img, probs in zip(fed_paths, prediction):
        max_index = int(np.argmax(probs))
        print("img:%s, predict: %s, prob: %f" % (img, label_dict_res[str(max_index)], probs[max_index]))
    

if __name__ == '__main__':
    init_tf()
    data_path = "/media/DATA2/sku_val"
    img_list = get_imgpath(data_path)
    print("there are %d images in %s" %(len(img_list), data_path))
    start = time.time()
    try:
        # Iterate over batch start offsets: this keeps the trailing partial
        # batch and avoids float division (len/batch_size) under Python 3.
        for i, begin in enumerate(range(0, len(img_list), batch_size)):
            print(str(i) + "-"*50)
            evaluate_image(img_list[begin:begin + batch_size])
        print("time cost:", time.time()-start)
    finally:
        # Release the session even if a batch raises.
        sess.close()

下一篇:导出 .pb 文件并预测图像 https://blog.csdn.net/u010397980/article/details/84932538

你可能感兴趣的:(tensorflow,tensorflow)