9.9 用极端随机森林ERF训练图像分类器

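极端随机森林（Extremely Random Forest，ERF；scikit-learn 中称为 Extremely Randomized Trees，对应 `ExtraTreesClassifier`）。它与随机森林类似，也是在随机选取的候选特征上划分节点，但不去搜索最优阈值，而是为每个候选特征随机抽取一个阈值，再从这些随机阈值中选出最好的一个作为划分规则，因此随机性更强，通常能进一步降低模型方差。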

定义一个类来处理ERF训练。

这里利用标签编码器（`LabelEncoder`）把文本形式的训练标签编码为数值，用于训练；分类时再把预测出的数值解码回文本标签。

class ERFTrainer(object):
    """Train an Extremely Random Forest (ExtraTrees) classifier on feature vectors.

    Text labels are encoded to numeric class ids with a LabelEncoder before
    training, and ``classify`` decodes predictions back to text labels.

    Args:
        X: training feature vectors, one per sample (converted with np.asarray).
        label_words: the text label of each sample in X.
        n_estimators: number of trees in the forest (default 100).
        max_depth: maximum depth of each tree (default 16).
        random_state: seed passed to ExtraTreesClassifier (default 0).
    """

    def __init__(self, X, label_words, n_estimators=100, max_depth=16,
                 random_state=0):
        self.le = preprocessing.LabelEncoder()
        self.clf = ExtraTreesClassifier(n_estimators=n_estimators,
                                        max_depth=max_depth,
                                        random_state=random_state)

        y = self.encode_labels(label_words)
        self.clf.fit(np.asarray(X), y)

    def encode_labels(self, label_words):
        """Fit the label encoder on ``label_words`` and return their numeric codes.

        The codes are returned as a float32 array; ``classify`` casts the
        classifier's predictions back to int before decoding them.
        """
        codes = self.le.fit_transform(label_words)
        return np.array(codes, dtype=np.float32)

    def classify(self, X):
        """Predict text labels for unseen feature vectors.

        Args:
            X: a 2-D array-like of feature vectors (one row per sample), or a
               single 1-D feature vector, which is treated as one sample.

        Returns:
            Array of predicted text labels, one per sample.
        """
        # sklearn's predict() requires 2-D input; promote a lone 1-D vector
        # to shape (1, n_features). 2-D input is passed through unchanged.
        samples = np.atleast_2d(np.asarray(X))
        label_nums = self.clf.predict(samples)
        # Predictions come back as the (float) codes used for training;
        # LabelEncoder.inverse_transform needs integer indices.
        return self.le.inverse_transform(np.asarray(label_nums).astype(int))

主函数

# -*- coding:utf8 -*-

import argparse
import pickle as pickle

import numpy as np
from sklearn.ensemble import ExtraTreesClassifier
from sklearn import preprocessing
  • 方式一：在 PyCharm 中直接把输入/输出文件路径写在代码里，方便调试。（下面两种主函数写法二选一即可；若同时放进同一个文件，两个 `if __name__ == '__main__':` 块都会依次执行。）
if __name__ == '__main__':
    import os

    # Hard-coded paths so the script can be run and debugged directly in an IDE.
    feature_map_file = 'feature_map/9_8.pkl'
    model_file = 'model/train_2.pkl'

    # Load the feature_map produced in section 9.8.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(feature_map_file, 'rb') as f:
        feature_map = pickle.load(f)
    if len(feature_map) == 0:
        # Fail with a clear message instead of an IndexError on feature_map[0].
        raise SystemExit('Feature map %r is empty; nothing to train on.'
                         % feature_map_file)

    # Extract features and labels. reshape() requires every feature vector
    # to hold exactly dim_size values (e.g. shape (1, dim_size)).
    label_words = [x['object_class'] for x in feature_map]
    dim_size = feature_map[0]['feature_vector'].shape[1]
    X = [np.reshape(x['feature_vector'], (dim_size,)) for x in feature_map]

    # Train the ERF classifier and save the model.
    erf = ERFTrainer(X, label_words)
    if model_file:
        model_dir = os.path.dirname(model_file)
        if model_dir:
            # Create the output directory if it does not exist yet,
            # otherwise open() below would fail.
            os.makedirs(model_dir, exist_ok=True)
        with open(model_file, 'wb') as f:
            pickle.dump(erf, f)
  • 方式二：通过命令行参数指定输入/输出文件路径
# Command-line argument definitions for the training script.
def build_arg_parser():
    """Return the argument parser for the ERF training script.

    Options:
        --feature-map-file (required): pickle file holding the feature map.
        --model-file (optional): where to store the trained model; parses
            to None when omitted.
    """
    arg_parser = argparse.ArgumentParser(description='Trains the classifier')
    add_option = arg_parser.add_argument
    add_option("--feature-map-file",
               dest="feature_map_file",
               required=True,
               help="Input pickle file containing the feature map")
    add_option("--model-file",
               dest="model_file",
               required=False,
               help="Output file where the trained model will be stored")
    return arg_parser


if __name__ == '__main__':
    import os
    import sys

    args = build_arg_parser().parse_args()
    feature_map_file = args.feature_map_file
    model_file = args.model_file

    # Load the feature map.
    # NOTE: pickle.load can execute arbitrary code; only pass trusted files.
    with open(feature_map_file, 'rb') as f:
        feature_map = pickle.load(f)
    if len(feature_map) == 0:
        # Fail with a clear message instead of an IndexError on feature_map[0].
        raise SystemExit('Feature map %r is empty; nothing to train on.'
                         % feature_map_file)

    # Extract feature vectors and the labels. reshape() requires every
    # feature vector to hold exactly dim_size values (e.g. shape (1, dim_size)).
    label_words = [x['object_class'] for x in feature_map]
    dim_size = feature_map[0]['feature_vector'].shape[1]
    X = [np.reshape(x['feature_vector'], (dim_size,)) for x in feature_map]

    # Train the Extremely Random Forests classifier
    erf = ERFTrainer(X, label_words)

    # --model-file is optional; without it the trained model is not saved.
    if model_file:
        model_dir = os.path.dirname(model_file)
        if model_dir:
            # Create the output directory if it does not exist yet,
            # otherwise open() below would fail.
            os.makedirs(model_dir, exist_ok=True)
        with open(model_file, 'wb') as f:
            pickle.dump(erf, f)
    else:
        print('Warning: --model-file not given; the trained model was not saved.',
              file=sys.stderr)

你可能感兴趣的:(9.9 用极端随机森林ERF训练图像分类器)