9.10 创建一个对象识别器

# Define a class that extracts an object tag from an image.
class ImageTagExtractor(object):
    """Predict an object tag for an image.

    Loads a trained classifier and a bag-of-words codebook from pickle files
    and uses them to label input images.
    """

    def __init__(self, model_file, codebook_file):
        """Load the trained model and the codebook.

        Args:
            model_file: Path to a pickle file holding the trained classifier
                (an object exposing a ``classify`` method).
            codebook_file: Path to a pickle file holding a
                ``(kmeans, centroids)`` pair.
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file; only load model/codebook files from trusted sources.
        with open(model_file, 'rb') as f:
            self.erf = pickle.load(f)
        with open(codebook_file, 'rb') as f:
            self.kmeans, self.centroids = pickle.load(f)

    def predict(self, img, scaling_size):
        """Return the predicted tag for ``img`` using the trained model.

        Args:
            img: Image array, e.g. as returned by ``cv2.imread``.
            scaling_size: Size passed to ``bf.resize_image`` before feature
                extraction.

        Returns:
            The first element of the classifier's ``classify`` result.

        Raises:
            ValueError: If ``img`` is None (``cv2.imread`` returns None when
                it cannot read a file).
        """
        if img is None:
            raise ValueError("img is None; was the image file read correctly?")
        img = bf.resize_image(img, scaling_size)
        feature_vector = bf.BagOfWords().construct_feature(
            img, self.kmeans, self.centroids)
        # Classify exactly once and return the top tag.
        return self.erf.classify(feature_vector)[0]

主程序

import argparse
import pickle as pickle
import cv2
import numpy as np
from python_machine_learn.c9 import build_feature as bf
from python_machine_learn.c9.train import ERFTrainer

from sklearn.ensemble import ExtraTreesClassifier
from sklearn import preprocessing
from sklearn.cluster import KMeans
  • 在PyCharm中直接在代码里设置输入信息，方便调试
if __name__ == '__main__':
    # Hard-coded inputs so the script can be run and debugged directly from
    # an IDE such as PyCharm, without command-line arguments.
    input_image_path = 'image/car_2.jpg'
    model_file = 'model/train_1.pkl'
    codebook_file = 'codebook/9_8.pkl'
    scaling_size = 200

    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable, so fail early with a clear message.
    input_image = cv2.imread(input_image_path)
    if input_image is None:
        raise SystemExit("Could not read input image: " + input_image_path)

    print("\nOutput:", ImageTagExtractor(model_file, codebook_file).predict(input_image, scaling_size))
  • 通过命令行参数运行的方式
# Define the command-line argument parser.
def build_arg_parser():
    """Build the argument parser for classifying a single image.

    Returns:
        An ``argparse.ArgumentParser`` that requires ``--input-image``,
        ``--model-file`` and ``--codebook-file``, stored under the
        ``input_image``, ``model_file`` and ``codebook_file`` attributes.
    """
    parser = argparse.ArgumentParser(
        description='Classifies an input image using a trained model and codebook')
    parser.add_argument('--input-image', dest='input_image', required=True,
                        help='Input image to be classified')
    parser.add_argument('--model-file', dest='model_file', required=True,
                        help='Input file containing the trained model')
    parser.add_argument('--codebook-file', dest='codebook_file', required=True,
                        help='Input file containing the codebook')
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    model_file = args.model_file
    codebook_file = args.codebook_file
    scaling_size = 200

    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable, so fail early with a clear message.
    input_image = cv2.imread(args.input_image)
    if input_image is None:
        raise SystemExit("Could not read input image: " + args.input_image)

    print("\nOutput:", ImageTagExtractor(model_file, codebook_file).predict(input_image, scaling_size))

结果

car_2.jpg

你可能感兴趣的:(9.10 创建一个对象识别器)