sklearn pipeline 和Gridsearch的使用

1. sklearn pipeline的使用

(1)简介

  当我们对训练集应用各种预处理操作时(特征标准化、主成分分析等等),
  我们都需要对测试集重复利用这些参数。

  pipeline 实现了对全部步骤的流式化封装和管理,可以很方便地使参数集在新数据集上被重复使用。

  pipeline 可以用于下面几处:

  • 模块化 Feature Transform,只需写很少的代码就能将新的 Feature 更新到训练集中。

  • 自动化 Grid Search,只要预先设定好使用的 Model 和参数的候选,就能自动搜索并记录最佳的 Model。

  • 自动化 Ensemble Generation,每隔一段时间将现有最好的 K 个 Model 拿来做 Ensemble。

(2)例子:

    注意pipeline中间每一步是 transformer,即它们必须包含 fit 和 transform 方法,或者 fit_transform。 

    最后一步是一个 Estimator,即最后一步模型要有 fit 方法,可以没有 transform 方法。

    然后用 Pipeline.fit对训练集进行训练,pipe_lr.fit(X_train, y_train) 
    再直接用 Pipeline.score 对测试集进行预测并评分 pipe_lr.score(X_test, y_test)

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris 

# 获取iris数据集
iris = load_iris()
X_data = iris.data
y_data = iris.target

X_train, X_test, y_train, y_test = train_test_split(X_data, y_data, \
                                                    test_size = 0.25, random_state = 1)

from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline

# 构建pipeline
pipe_lr = Pipeline([('sc', StandardScaler()),
                    ('pca', PCA(n_components=2)),
                    ('clf', LogisticRegression(random_state=1))
                    ])
pipe_lr.fit(X_train, y_train)
print('Test accuracy: %.3f' % pipe_lr.score(X_test, y_test))

Test accuracy: 0.842


2. GridSearch和pipeline联合使用

 第一个参数estimator为进行训练的模型,parameters是模型的参数词典。parameters中的key为模型参数名称,value为模型参数的值的元组(包含多个可选择的值)。

from sklearn.svm import SVC
my_svc = SVC()
parameters = {'kernel':('linear','rbf'), 'C':[1, 2], 'gamma':[0.125, 0.5]}

from sklearn.model_selection import GridSearchCV
# 构建pipeline
pipe_svc = Pipeline([('sc', StandardScaler()),
                    ('pca', PCA(n_components=2)),
                    ('clf', GridSearchCV(my_svc, parameters, n_jobs=-1))
                    ])

pipe_svc.fit(X_train, y_train)
print('Test accuracy: %.3f' % pipe_svc.score(X_test, y_test))


参考文章:https://blog.csdn.net/sinat_26917383/article/details/77917881

你可能感兴趣的:(sklearn)