Machine Learning (5): Cross validation with the sklearn library

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 25 20:10:51 2016

@author: SIrius
test sklearn
"""

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split   # sklearn.cross_validation was removed; use model_selection
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score
import matplotlib.pyplot as plt

iris=datasets.load_iris()
data_x=iris.data
data_y=iris.target


x_train, x_test, y_train, y_test = train_test_split(
    data_x, data_y, test_size=0.2)  # shuffle the data and hold out 20% as a test set
"""
选取一个合适的n_neighbors值
"""
k_range = range(1, 40)
k_scores = []
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)
    # accuracy is already a positive score, so no minus sign is needed here
    scores = cross_val_score(knn, data_x, data_y, cv=10, scoring='accuracy')
    # loss = -cross_val_score(knn, data_x, data_y, cv=10, scoring='neg_mean_squared_error')
    k_scores.append(scores.mean())

plt.plot(k_range, k_scores)
plt.xlabel('value of KNN n_neighbors')
plt.ylabel('cross-validation accuracy')
plt.show()

The result is as follows:
[Figure 1: cross-validation accuracy as a function of n_neighbors]
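The train/test split created at the top is never actually used above. As a minimal sketch (assuming we simply take the k with the highest mean cross-validation accuracy, which is not spelled out in the original post), the chosen k can be refit on the training split and scored once on the held-out test set:

# Minimal sketch (illustrative, not from the original post):
# pick the k with the highest mean cross-validation accuracy,
# refit on the training split, and evaluate once on the untouched test set.
best_k = k_range[int(np.argmax(k_scores))]
knn = KNeighborsClassifier(n_neighbors=best_k)
knn.fit(x_train, y_train)
print('best k:', best_k, 'test accuracy:', knn.score(x_test, y_test))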




How to tell from the loss curve when a model is trained well enough (and guard against overfitting)

#
from sklearn.model_selection import learning_curve  # sklearn.learning_curve was removed; use model_selection
from sklearn.datasets import load_digits
from sklearn.svm import SVC

"""
怎样查看loss曲线,检查是否过拟合
"""
digits=load_digits()
x=digits.data
y=digits.target

# SVC model; different gamma values produce very different curves
train_sizes, train_loss, test_loss = learning_curve(
    SVC(gamma=0.01), x, y, cv=10, scoring='neg_mean_squared_error',
    train_sizes=[0.1, 0.25, 0.5, 0.75, 1])  # evaluate the loss at 10%, 25%, ... of the training set
train_loss_mean = -np.mean(train_loss, axis=1)  # scores are negative MSE; negate to get a positive loss
test_loss_mean = -np.mean(test_loss, axis=1)

plt.plot(train_sizes, train_loss_mean, 'o-', color='r', label='Training')
plt.plot(train_sizes, test_loss_mean, 'o-', color='b', label='Cross-validation')
plt.xlabel('Training examples')
plt.ylabel('Loss')
plt.legend()
plt.show()

The result is as follows:
[Figure 2: training vs. cross-validation loss as the training set grows]
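With gamma=0.01 the training loss drops towards zero while the cross-validation loss stays high, which is the classic overfitting pattern. A rough way to read that gap numerically at the largest training size (an illustrative check, not part of the original code, and the threshold for "too large" is left to the reader):

# Rough sketch (illustrative): compare the two curves at the largest training size.
# A training loss near zero combined with a much higher CV loss suggests overfitting.
gap = test_loss_mean[-1] - train_loss_mean[-1]
print('final training loss: %.4f' % train_loss_mean[-1])
print('final CV loss:       %.4f' % test_loss_mean[-1])
print('gap (overfitting indicator): %.4f' % gap)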

How do we choose a suitable gamma value?

"""
--------------------------------------------------------------------
通过validation curve选取gamma值
"""
from sklearn.model_selection import validation_curve  # sklearn.learning_curve was removed; use model_selection

pa_range = np.logspace(-6, -2.3, 5)
train_loss, test_loss = validation_curve(
    SVC(), x, y, param_name='gamma', param_range=pa_range, cv=10,
    scoring='neg_mean_squared_error')

train_loss_mean = -np.mean(train_loss, axis=1)  # scores are negative MSE; negate to get a positive loss
test_loss_mean = -np.mean(test_loss, axis=1)

plt.plot(pa_range, train_loss_mean, 'o-', color='r', label='Training')
plt.plot(pa_range, test_loss_mean, 'o-', color='b', label='Cross-validation')
plt.xlabel('gamma')
plt.ylabel('Loss')
plt.legend()
plt.show() 

The result is as follows:
[Figure 3: training vs. cross-validation loss for different gamma values]
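Once the validation curve is plotted, the natural pick is the gamma with the lowest cross-validation loss. A minimal sketch of that step, together with the GridSearchCV equivalent that automates the same search (both are assumptions for illustration, not part of the original post):

# Minimal sketch (illustrative): take the gamma with the lowest mean CV loss and refit.
best_gamma = pa_range[int(np.argmin(test_loss_mean))]
final_model = SVC(gamma=best_gamma)
final_model.fit(x, y)  # refit on all the data with the chosen gamma
print('chosen gamma:', best_gamma)

# Equivalent automated search with GridSearchCV over the same parameter range.
from sklearn.model_selection import GridSearchCV
search = GridSearchCV(SVC(), {'gamma': pa_range}, cv=10, scoring='neg_mean_squared_error')
search.fit(x, y)
print(search.best_params_)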
