【sklearn】网格搜索 from sklearn.model_selection import GridSearchCV

      GridSearchCV用于系统地遍历模型的多种参数组合,通过交叉验证确定最佳参数。

1.GridSearchCV参数   

# 不常用的参数

  • pre_dispatch
    • 没看懂

  • refit    
    • 默认为True
    • 在参数搜索参数后,用最佳参数的结果fit一遍全部数据集
  • iid                 
    • 默认为True
    • 各个样本fold概率分布一致,误差估计为所有样本之和

# 常用的参数

  • cv
    • 默认为3
    • 指定fold个数,即默认三折交叉验证
  • verbose
    • 默认为0
    • 值为0时,不输出训练过程;值为1时,偶尔输出训练过程;值>1时,对每个子模型都输出训练过程
  • n_jobs
    • cpu个数
    • 值为-1时,使用全部CPU;值为1时,使用1个CPU;值为2时,使用2个CPU
  • estimator       
    • 分类器
  • param_grid    
    • 参数范围,值为列表/字典
  • scoring
准确度评价标准,socring参数选择链接

【sklearn】网格搜索 from sklearn.model_selection import GridSearchCV_第1张图片


2.常用属性  

  • best_score_
    • 最佳模型下的分数
  • best_params_
    • 最佳模型参数
  • grid_scores_
    • 模型不同参数下交叉验证的平均分
  • cv_results_   具体用法
    • 模型不同参数下交叉验证的结果
  • best_estimator_
    • 最佳分类器

注:grid_scores_在sklearn0.20版本中将被删除。使用cv_results_替代


3.常用函数

  • score(x_test,y_test)
    • 最佳模型在测试集下的分数


4.例子

  1 # -*- coding: utf-8 -*-
  2 """
  3 # 数据:20类新闻文本
  4 # 模型:svc
  5 # 调参:gridsearch
  6 """
  7 ### 加载模块
  8 import numpy as np
  9 import pandas as pd
 10 
 11 ### 载入数据
 12 from sklearn.datasets import fetch_20newsgroups                          # 20类新闻数据
 13 news = fetch_20newsgroups(subset='all')                                  # 生成20类新闻数据
 14 
 15 ### 数据分割
 16 from sklearn.cross_validation import train_test_split
 17 X_train, X_test, y_train, y_test = train_test_split(news.data[:300],
 18                                                     news.target[:300],
 19                                                     test_size=0.25,      # 测试集占比25%
 20                                                     random_state=33)     # 随机数
 21 ### pipe-line
 22 from sklearn.feature_extraction.text import TfidfVectorizer              # 特征提取
 23 from sklearn.svm import SVC                                              # 载入模型
 24 from sklearn.pipeline import Pipeline                                    # pipe_line模式
 25 clf = Pipeline([('vect', TfidfVectorizer(stop_words='english', analyzer='word')),
 26                 ('svc', SVC())])
 27 
 28 ### 网格搜索
 29 from sklearn.model_selection import GridSearchCV
 30 parameters = {'svc__gamma': np.logspace(-1, 1)}                           # 参数范围(字典类型)
 31 
 32 gs = GridSearchCV(clf,          # 模型
 33                   parameters,   # 参数字典
 34                   n_jobs=1,     # 使用1个cpu
 35                   verbose=0,    # 不打印中间过程
 36                   cv=5)         # 5折交叉验证
 37 
 38 gs.fit(X_train, y_train)        # 在训练集上进行网格搜索
 39 
 40 ### 最佳参数在测试集上模型分数
 41 print("best:%f using %s" % (gs.best_score_,gs.best_params_))
 42 
 43 ### 测试集下的分数
 44 print("test datasets score" % gs.score(X_test, y_test))
 45 
 46 ### 模型不同参数下的分数
 47 # 方式一(0.20版本将删除)
 48 print(gs.grid_scores_)
 49 
 50 # 方式二(0.20推荐的方式)
 51 means = gs.cv_results_['mean_test_score']
 52 params =  gs.cv_results_['params']
 53 
 54 for mean, param in zip(means,params):
 55     print("%f with: %r" % (mean,param))

转载于:https://www.cnblogs.com/wanglei5205/p/8581354.html

你可能感兴趣的:(【sklearn】网格搜索 from sklearn.model_selection import GridSearchCV)