Inspecting overfitting in Python with validation_curve

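validation_curve: shows the score an estimator achieves as a single parameter takes different values.
                            Plotting this curve makes it easier to see whether changing that parameter leads to overfitting.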

validation_curve(estimator, X, y, param_name, param_range, groups=None, cv=None, scoring=None, n_jobs=1, pre_dispatch='all', verbose=0)
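                estimator: an object that implements the fit and predict methods
                X: the training vectors
                y: the target values for X, for classification or regression
                param_name: the name of the parameter that will be varied
                param_range: the values of param_name to evaluate
                cv: if an integer n is passed, the data is split into n folds; each fold in turn serves as the validation set and the other n-1 folds are used for training (the default was 3 folds in older scikit-learn releases and has been 5 folds since version 0.22)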
Returns: (train_scores, test_scores)
                train_scores: scores on the training sets
                test_scores: scores on the held-out (validation) sets
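
To make the return value concrete, here is a minimal sketch of the shapes involved. It assumes the iris dataset and four gamma values purely for speed; the names X_small, y_small, gamma_values and n_folds are illustrative and do not come from the script in this post.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import validation_curve
from sklearn.svm import SVC

# Assumption: a small dataset is used here only so the sketch runs quickly.
X_small, y_small = load_iris(return_X_y=True)
gamma_values = np.logspace(-3, 1, 4)   # 4 candidate values for gamma
n_folds = 5                            # integer cv: number of folds

train_scores, test_scores = validation_curve(
    SVC(), X_small, y_small,
    param_name='gamma', param_range=gamma_values,
    cv=n_folds, scoring='accuracy')

# One row per parameter value, one column per cross-validation fold.
print(train_scores.shape)
print(test_scores.shape)
```

Both arrays have len(gamma_values) rows and n_folds columns, which is why the script below averages with axis=1 to get one mean score per parameter value.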

print(__doc__)

import matplotlib.pyplot as plot
import numpy as np
import time

from sklearn.datasets import load_digits
from sklearn.svm import SVC
from sklearn.model_selection import validation_curve

startTime = time.perf_counter()  # time.clock() was removed in Python 3.8
#print(startTime)
digits = load_digits()  # 1797 digit images, each a 64-dimensional vector; exposes .data and .target
x, y = digits.data, digits.target

param_range = np.logspace(-6, -1, 5)
#print(param_range)
train_scores, test_scores = validation_curve(
    SVC(),
    x, y, param_name='gamma', param_range=param_range,
    cv=10, scoring='accuracy', n_jobs=1)

train_scores_mean = np.mean(train_scores, axis=1)   # axis=1: mean of each row, i.e. over the CV folds for one gamma value
train_scores_std = np.std(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)

plot.title('Validation Curve with SVM')
plot.xlabel(r'$\gamma$')  # raw string so the backslash is passed through to matplotlib
plot.ylabel('Score')
plot.ylim(0.0, 1.1)
lw = 2

# plot.plot(param_range, train_scores_mean, 'o-', color='r', label='Training')
# plot.plot(param_range, test_scores_mean,'o-', color='g', label='Cross-validation')

plot.semilogx(param_range, train_scores_mean, label='Training score',
              color='darkorange', lw=lw)
plot.fill_between(param_range, train_scores_mean - train_scores_std,
                  train_scores_mean + train_scores_std, alpha=0.2,
                  color='darkorange', lw=lw)
plot.semilogx(param_range, test_scores_mean, label='Cross-validation score',
              color='navy', lw=lw)
plot.fill_between(param_range, test_scores_mean - test_scores_std,
                  test_scores_mean + test_scores_std, alpha=0.2,
                  color='navy', lw=lw)

plot.legend(loc='best')
# print the elapsed running time
print(time.perf_counter() - startTime) # the original run printed 43.409535299999995
plot.show()

Result:

[Figure 1: the validation curve produced by the code above, showing training and cross-validation scores against gamma]
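
One way to read such a curve numerically, rather than by eye, is to compare the mean training score with the mean cross-validation score at each parameter value: a large gap suggests overfitting. The sketch below assumes a hypothetical helper, summarize_validation_curve, and made-up score arrays; neither is part of scikit-learn or of the script above.

```python
import numpy as np

def summarize_validation_curve(param_range, train_scores, test_scores):
    """Return the parameter value with the best mean CV score and the
    per-parameter gap between mean training and mean CV scores."""
    train_mean = np.mean(train_scores, axis=1)
    test_mean = np.mean(test_scores, axis=1)
    best_index = int(np.argmax(test_mean))
    gaps = train_mean - test_mean   # large positive gap: likely overfitting
    return param_range[best_index], gaps

# Made-up scores for illustration only: 3 parameter values, 2 folds each.
demo_range = np.array([1e-4, 1e-2, 1.0])
demo_train = np.array([[0.60, 0.62], [0.97, 0.99], [1.00, 1.00]])
demo_test = np.array([[0.58, 0.60], [0.95, 0.97], [0.40, 0.44]])

best_gamma, gaps = summarize_validation_curve(demo_range, demo_train, demo_test)
print(best_gamma)
print(gaps)
```

In this made-up data the middle value has the best cross-validation score, while the largest value fits the training folds perfectly but scores far lower on the held-out folds, which is the overfitting pattern the plot is meant to reveal. The same function can be called with the param_range, train_scores and test_scores computed by the script.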

... to be continued
