4-6 网格搜索与k近邻算法中更多超参数
Notebook 示例
Notebook 源码
[1]
import numpy as np
from sklearn import datasets
[2]
# Load scikit-learn's digits dataset; digits.data becomes X, digits.target becomes y.
digits = datasets.load_digits()
X, y = digits.data, digits.target
[3]
from sklearn.model_selection import train_test_split
# Hold out 30% of the samples for testing; the fixed random_state makes the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=111
)
[4]
from sklearn.neighbors import KNeighborsClassifier
# Baseline: kNN with a hand-picked k=6. The last expression (test-set
# accuracy) is what the notebook displays below the cell.
knn_clf = KNeighborsClassifier( n_neighbors = 6 )
knn_clf.fit(X_train,y_train)
knn_clf.score(X_test,y_test)
0.9833333333333333
Grid Search
[5]
# Hyper-parameter grid for KNeighborsClassifier. The variable name param_gid
# is kept as-is because later cells pass it to GridSearchCV.
# Two sub-grids: the Minkowski power `p` is only searched together with
# distance weighting; the uniform-weight candidates keep the estimator's p.
param_gid = [
    {
        # Must be spelled exactly 'uniform': KNeighborsClassifier rejects any
        # other string with ValueError, which GridSearchCV records as a
        # failed fit scored nan (see the FitFailedWarning logged for 'unifrom').
        'weights': ['uniform'],
        'n_neighbors': list(range(1, 11)),
    },
    {
        'weights': ['distance'],
        'n_neighbors': list(range(1, 11)),
        'p': list(range(1, 6)),
    },
]
[6]
# Fresh, default-configured estimator used as the base for the grid search.
knn_clf = KNeighborsClassifier()
[7]
from sklearn.model_selection import GridSearchCV
# Exhaustive cross-validated search over every parameter combination in param_gid.
grid_search = GridSearchCV(estimator=knn_clf, param_grid=param_gid)
[8]
%%time
# Fit every candidate on each cross-validation fold, then refit the best
# candidate on the whole training set. %%time (must stay on the first line
# of the cell) reports how long the search takes.
grid_search.fit(X_train,y_train)
CPU times: total: 2min 15s
Wall time: 2min 18s
F:\anaconda\lib\site-packages\sklearn\model_selection\_validation.py:372: FitFailedWarning:
50 fits failed out of a total of 300.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.
Below are more details about the failures:
--------------------------------------------------------------------------------
50 fits failed with the following error:
Traceback (most recent call last):
File "F:\anaconda\lib\site-packages\sklearn\model_selection\_validation.py", line 680, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
File "F:\anaconda\lib\site-packages\sklearn\neighbors\_classification.py", line 196, in fit
self.weights = _check_weights(self.weights)
File "F:\anaconda\lib\site-packages\sklearn\neighbors\_base.py", line 82, in _check_weights
raise ValueError(
ValueError: weights not recognized: should be 'uniform', 'distance', or a callable function
warnings.warn(some_fits_failed_message, FitFailedWarning)
F:\anaconda\lib\site-packages\sklearn\model_selection\_search.py:969: UserWarning: One or more of the test scores are non-finite: [ nan nan nan nan nan nan
nan nan nan nan 0.98011446 0.98965724
0.98965724 0.99204452 0.98966041 0.98090811 0.98965724 0.98965724
0.99204452 0.98966041 0.98408904 0.98726997 0.98646999 0.98726681
0.98488585 0.98249542 0.98806678 0.98886359 0.98886359 0.98726997
0.98249542 0.98647948 0.98966041 0.98726997 0.98488902 0.98249542
0.98488585 0.9856795 0.98885727 0.9856795 0.98090179 0.98329539
0.9856795 0.98806362 0.98488269 0.98010181 0.98170176 0.9856795
0.98726997 0.98487637 0.97692405 0.98408904 0.98329223 0.98647948
0.98488585 0.97851135 0.98010814 0.98488269 0.98726997 0.98726997]
warnings.warn(
GridSearchCV(estimator=KNeighborsClassifier(),
param_grid=[{'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'weights': ['unifrom']},
{'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'p': [1, 2, 3, 4, 5], 'weights': ['distance']}])
[9]
# Best candidate found by the search (refit on the full training set).
grid_search.best_estimator_
KNeighborsClassifier(n_neighbors=1, p=4, weights='distance')
[10]
# Mean cross-validated score of the best candidate.
grid_search.best_score_
0.9920445203313729
[11]
# Parameter setting that achieved best_score_.
grid_search.best_params_
{'n_neighbors': 1, 'p': 4, 'weights': 'distance'}
若随机种子 random_state=111,则 grid_search.best_score_ = 0.9920445203313729, grid_search.best_params_ = {'n_neighbors': 1, 'p': 4, 'weights': 'distance'}。注意:该结果是在 'weights' 被误写为 'unifrom' 时得到的,'uniform' 子网格的 50 次拟合全部失败(得分为 nan),并未参与比较。
[12]
# Use the tuned estimator from here on (note: knn_clf now carries p=4 etc.).
knn_clf = grid_search.best_estimator_
[13]
# Predicted labels for the test set.
knn_clf.predict(X_test)
array([7, 1, 2, 9, 5, 8, 6, 4, 8, 3, 8, 4, 4, 5, 7, 1, 6, 1, 0, 6, 6, 8,
8, 0, 8, 4, 5, 8, 0, 0, 3, 3, 5, 2, 1, 4, 8, 6, 7, 3, 3, 9, 6, 0,
4, 9, 7, 3, 8, 7, 4, 3, 5, 0, 3, 1, 7, 6, 5, 7, 6, 0, 9, 7, 7, 8,
2, 8, 6, 6, 1, 1, 2, 6, 4, 6, 4, 8, 6, 9, 8, 1, 3, 4, 4, 2, 0, 7,
6, 0, 8, 2, 0, 5, 8, 5, 3, 3, 7, 4, 7, 3, 4, 2, 4, 9, 1, 8, 5, 1,
2, 7, 0, 2, 8, 9, 7, 5, 7, 7, 8, 8, 9, 2, 3, 9, 7, 7, 8, 2, 5, 3,
2, 4, 0, 1, 4, 8, 7, 9, 6, 8, 1, 5, 2, 6, 1, 4, 1, 6, 5, 3, 4, 2,
2, 7, 0, 7, 1, 5, 4, 6, 1, 7, 4, 9, 6, 8, 5, 8, 4, 3, 3, 2, 5, 6,
7, 9, 0, 2, 0, 5, 4, 8, 0, 8, 6, 9, 7, 3, 1, 9, 4, 2, 7, 9, 4, 0,
5, 2, 8, 2, 9, 1, 8, 5, 4, 5, 7, 7, 5, 5, 0, 1, 4, 4, 6, 5, 7, 6,
0, 6, 7, 1, 9, 0, 6, 1, 2, 9, 1, 5, 3, 0, 2, 1, 0, 9, 3, 4, 1, 0,
9, 9, 2, 0, 5, 3, 6, 5, 5, 3, 9, 1, 2, 8, 7, 4, 9, 8, 8, 1, 3, 1,
6, 3, 0, 7, 2, 4, 7, 2, 5, 0, 6, 4, 7, 4, 1, 0, 3, 1, 8, 0, 5, 6,
9, 5, 5, 0, 6, 0, 5, 2, 9, 7, 2, 9, 1, 0, 3, 5, 8, 8, 0, 4, 3, 4,
6, 1, 6, 1, 7, 3, 3, 2, 3, 6, 7, 1, 0, 1, 9, 6, 6, 6, 8, 2, 3, 5,
9, 4, 4, 5, 3, 9, 7, 1, 3, 0, 0, 8, 6, 9, 7, 9, 6, 4, 2, 7, 2, 6,
5, 4, 1, 7, 9, 0, 1, 1, 7, 5, 3, 3, 7, 4, 9, 0, 8, 6, 0, 9, 1, 9,
7, 8, 8, 8, 6, 2, 1, 3, 0, 2, 3, 6, 8, 1, 6, 1, 3, 9, 6, 2, 5, 2,
9, 7, 7, 6, 5, 8, 0, 1, 8, 6, 3, 5, 0, 4, 3, 9, 9, 3, 4, 3, 7, 9,
2, 3, 5, 3, 9, 3, 1, 4, 7, 7, 1, 7, 4, 3, 0, 8, 0, 9, 6, 3, 9, 8,
3, 9, 9, 9, 4, 1, 6, 7, 7, 2, 0, 1, 0, 7, 5, 7, 6, 1, 5, 0, 6, 9,
5, 1, 2, 1, 7, 5, 2, 1, 8, 1, 8, 8, 2, 8, 6, 8, 7, 0, 9, 9, 6, 2,
0, 9, 6, 3, 4, 3, 0, 8, 5, 4, 8, 6, 4, 5, 2, 5, 6, 1, 0, 5, 7, 0,
9, 5, 3, 2, 9, 3, 0, 6, 4, 8, 3, 2, 3, 6, 6, 8, 1, 9, 4, 3, 1, 1,
4, 5, 4, 3, 7, 5, 3, 3, 7, 8, 1, 0])
[14]
# Test-set accuracy of the tuned estimator.
knn_clf.score(X_test,y_test)
0.9907407407407407
[15]
%%time
grid_search = GridSearchCV(knn_clf,param_gid,n_jobs= 4, verbose = 2)
grid_search.fit(X_train,y_train)
# 创建多个分类器来比较,可以并行处理,n_jobs 为分配核的数量,默认为单核 1 .-1为全核。
# verbose,及时输出一些信息,值越大越详细
Fitting 5 folds for each of 60 candidates, totalling 300 fits
CPU times: total: 484 ms
Wall time: 1min 28s
F:\anaconda\lib\site-packages\sklearn\model_selection\_validation.py:372: FitFailedWarning:
50 fits failed out of a total of 300.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.
Below are more details about the failures:
--------------------------------------------------------------------------------
50 fits failed with the following error:
Traceback (most recent call last):
File "F:\anaconda\lib\site-packages\sklearn\model_selection\_validation.py", line 680, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
File "F:\anaconda\lib\site-packages\sklearn\neighbors\_classification.py", line 196, in fit
self.weights = _check_weights(self.weights)
File "F:\anaconda\lib\site-packages\sklearn\neighbors\_base.py", line 82, in _check_weights
raise ValueError(
ValueError: weights not recognized: should be 'uniform', 'distance', or a callable function
warnings.warn(some_fits_failed_message, FitFailedWarning)
F:\anaconda\lib\site-packages\sklearn\model_selection\_search.py:969: UserWarning: One or more of the test scores are non-finite: [ nan nan nan nan nan nan
nan nan nan nan 0.98011446 0.98965724
0.98965724 0.99204452 0.98966041 0.98090811 0.98965724 0.98965724
0.99204452 0.98966041 0.98408904 0.98726997 0.98646999 0.98726681
0.98488585 0.98249542 0.98806678 0.98886359 0.98886359 0.98726997
0.98249542 0.98647948 0.98966041 0.98726997 0.98488902 0.98249542
0.98488585 0.9856795 0.98885727 0.9856795 0.98090179 0.98329539
0.9856795 0.98806362 0.98488269 0.98010181 0.98170176 0.9856795
0.98726997 0.98487637 0.97692405 0.98408904 0.98329223 0.98647948
0.98488585 0.97851135 0.98010814 0.98488269 0.98726997 0.98726997]
warnings.warn(
GridSearchCV(estimator=KNeighborsClassifier(n_neighbors=1, p=4,
weights='distance'),
n_jobs=4,
param_grid=[{'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'weights': ['unifrom']},
{'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
'p': [1, 2, 3, 4, 5], 'weights': ['distance']}],
verbose=2)
4-7 数据归一化
Notebook 示例
Notebook 源码
数据归一化处理
[1]
import numpy as np
import matplotlib.pyplot as plt
最值归一化 Min-Max Normalization
[2]
# 100 random integers drawn from [0, 100); no seed is set, so values differ per run.
x = np.random.randint(low=0, high=100, size=100)
[3]
# Display the raw random vector.
x
array([ 2, 58, 55, 40, 68, 7, 72, 4, 50, 89, 96, 19, 71, 6, 41, 40, 63,
4, 2, 26, 79, 62, 62, 92, 49, 46, 75, 47, 44, 91, 40, 67, 38, 14,
13, 0, 93, 15, 20, 50, 38, 31, 41, 22, 36, 85, 64, 87, 98, 65, 31,
10, 46, 22, 86, 24, 68, 25, 11, 31, 22, 87, 84, 18, 58, 87, 4, 15,
64, 92, 70, 59, 74, 81, 47, 33, 1, 93, 6, 37, 62, 17, 58, 56, 98,
53, 2, 70, 36, 17, 21, 66, 1, 79, 60, 71, 89, 71, 61, 30])
[5]
# Min-max normalization of x into [0, 1]. Caveat: if every element is equal,
# max - min is 0 and this divides by zero (numpy yields nan with a warning).
( x - np.min(x)) / (np.max(x) - np.min(x))
array([0.02040816, 0.59183673, 0.56122449, 0.40816327, 0.69387755,
0.07142857, 0.73469388, 0.04081633, 0.51020408, 0.90816327,
0.97959184, 0.19387755, 0.7244898 , 0.06122449, 0.41836735,
0.40816327, 0.64285714, 0.04081633, 0.02040816, 0.26530612,
0.80612245, 0.63265306, 0.63265306, 0.93877551, 0.5 ,
0.46938776, 0.76530612, 0.47959184, 0.44897959, 0.92857143,
0.40816327, 0.68367347, 0.3877551 , 0.14285714, 0.13265306,
0. , 0.94897959, 0.15306122, 0.20408163, 0.51020408,
0.3877551 , 0.31632653, 0.41836735, 0.2244898 , 0.36734694,
0.86734694, 0.65306122, 0.8877551 , 1. , 0.66326531,
0.31632653, 0.10204082, 0.46938776, 0.2244898 , 0.87755102,
0.24489796, 0.69387755, 0.25510204, 0.1122449 , 0.31632653,
0.2244898 , 0.8877551 , 0.85714286, 0.18367347, 0.59183673,
0.8877551 , 0.04081633, 0.15306122, 0.65306122, 0.93877551,
0.71428571, 0.60204082, 0.75510204, 0.82653061, 0.47959184,
0.33673469, 0.01020408, 0.94897959, 0.06122449, 0.37755102,
0.63265306, 0.17346939, 0.59183673, 0.57142857, 1. ,
0.54081633, 0.02040816, 0.71428571, 0.36734694, 0.17346939,
0.21428571, 0.67346939, 0.01020408, 0.80612245, 0.6122449 ,
0.7244898 , 0.90816327, 0.7244898 , 0.62244898, 0.30612245])
[6]
# 50 samples x 2 features of random integers in [0, 100).
X = np.random.randint(low=0, high=100, size=(50, 2))
[7]
# Inspect the first 10 rows.
X[:10,:]
array([[19, 14],
[23, 82],
[ 4, 17],
[44, 58],
[23, 91],
[46, 17],
[34, 25],
[29, 39],
[69, 61],
[70, 25]])
[9]
# Convert to a float copy so the in-place column normalization below can store
# fractional results (writing floats into an integer array would truncate them).
X = X.astype(float)
[10]
# First 10 rows again, now as floats.
X[:10,:]
array([[19., 14.],
[23., 82.],
[ 4., 17.],
[44., 58.],
[23., 91.],
[46., 17.],
[34., 25.],
[29., 39.],
[69., 61.],
[70., 25.]])
[18]
# Min-max scale column 0 in place: (v - min) / (max - min).
X[:, 0] = (X[:, 0] - X[:, 0].min()) / (X[:, 0].max() - X[:, 0].min())
[19]
# Min-max scale column 1 in place: (v - min) / (max - min).
X[:, 1] = (X[:, 1] - X[:, 1].min()) / (X[:, 1].max() - X[:, 1].min())
[20]
# First 10 rows after min-max normalization of both columns.
X[:10,:]
array([[0.19191919, 0.11458333],
[0.23232323, 0.82291667],
[0.04040404, 0.14583333],
[0.44444444, 0.57291667],
[0.23232323, 0.91666667],
[0.46464646, 0.14583333],
[0.34343434, 0.22916667],
[0.29292929, 0.375 ],
[0.6969697 , 0.60416667],
[0.70707071, 0.22916667]])
[21]
# Scatter plot of the two min-max normalized columns.
plt.scatter(X[:,0],X[:,1])
[22]
# Mean of normalized column 0.
np.mean(X[:,0])
0.4503030303030303
[26]
# Standard deviation of normalized column 0.
np.std(X[:,0])
0.32224653972392703
[24]
# Mean of normalized column 1.
# NOTE(review): the recorded output for this cell is identical to column 0's
# mean; it is probably stale (out-of-order execution) — re-run to confirm.
np.mean(X[:,1])
0.4503030303030303
[27]
# Standard deviation of normalized column 1.
np.std(X[:,1])
0.3004160887650475
均值方差归一化 Standardization
[28]
# Another 50 x 2 matrix of random integers in [0, 100) for standardization.
X2 = np.random.randint(low=0, high=100, size=(50, 2))
[29]
# Float copy so standardized (fractional, possibly negative) values can be stored in place.
X2 = X2.astype(float)
[36]
# Standardize column 0 in place: (v - mean) / std (population std, ddof=0).
X2[:, 0] = (X2[:, 0] - X2[:, 0].mean()) / X2[:, 0].std()
[37]
# Standardize column 1 in place: (v - mean) / std (population std, ddof=0).
X2[:, 1] = (X2[:, 1] - X2[:, 1].mean()) / X2[:, 1].std()
[38]
# Scatter plot of the two standardized columns.
plt.scatter(X2[:,0],X2[:,1])
[39]
# Mean of standardized column 0 (expected ~0).
np.mean(X2[:,0])
0.0
[40]
# Standard deviation of standardized column 0 (expected 1).
np.std(X2[:,0])
1.0
[41]
# Mean of standardized column 1; zero up to floating-point rounding error.
np.mean(X2[:,1])
-4.4408920985006264e-17
[42]
# Standard deviation of standardized column 1 (expected 1).
np.std(X2[:,1])
1.0
4-8 scikit-learn中的Scaler
Notebook 示例
Notebook 源码
[1]
import numpy as np
from sklearn import datasets
[2]
# Load scikit-learn's iris dataset.
iris = datasets.load_iris()
[3]
# Feature matrix and label vector of the iris dataset.
X, y = iris.data, iris.target
[4]
# First 10 rows of the raw (unscaled) iris features.
X[:10,:]
array([[5.1, 3.5, 1.4, 0.2],
[4.9, 3. , 1.4, 0.2],
[4.7, 3.2, 1.3, 0.2],
[4.6, 3.1, 1.5, 0.2],
[5. , 3.6, 1.4, 0.2],
[5.4, 3.9, 1.7, 0.4],
[4.6, 3.4, 1.4, 0.3],
[5. , 3.4, 1.5, 0.2],
[4.4, 2.9, 1.4, 0.2],
[4.9, 3.1, 1.5, 0.1]])
[5]
from sklearn.model_selection import train_test_split
# 70/30 train/test split with a fixed seed for reproducibility.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=666
)
scikit-learn中的StandardScaler
[6]
from sklearn.preprocessing import StandardScaler
[7]
# Scaler that will learn per-feature mean and standard deviation in fit().
standardScaler = StandardScaler()
[8]
# Learn the statistics from the training split only, so no information from
# the test split leaks into preprocessing.
standardScaler.fit(X_train)
StandardScaler()
[9]
# Per-feature means learned by fit().
standardScaler.mean_
array([5.81619048, 3.08761905, 3.66952381, 1.15714286])
[10]
# Per-feature scaling factors learned by fit(), i.e. the standard deviations
# the data is divided by. This attribute replaces the old, deprecated std_.
standardScaler.scale_
array([0.80747977, 0.43789436, 1.76166176, 0.75464998])
[11]
# Preview of the standardized training data; transform() returns a new array
# and does not modify X_train itself.
standardScaler.transform(X_train)
array([[-0.63926119, 1.39846731, -1.2315212 , -1.26832688],
[-1.01078752, 0.94173615, -1.17475662, -0.73827982],
[-1.75384019, -0.42845732, -1.28828579, -1.26832688],
[-0.02005063, -0.88518848, 0.13082885, 0.05679076],
[-0.7631033 , 0.71337057, -1.28828579, -1.26832688],
[-1.50615597, 0.71337057, -1.28828579, -1.13581511],
[ 0.84684415, 0.25663941, 0.81200388, 1.11688486],
[-0.14389274, -0.42845732, 0.30112261, 0.18930252],
[ 0.97068626, -0.20009175, 0.41465178, 0.32181428],
[ 0.2276336 , -0.42845732, 0.47141637, 0.45432605],
[-1.38231385, 0.25663941, -1.17475662, -1.26832688],
[-1.13462963, 1.17010173, -1.28828579, -1.40083864],
[ 1.09452838, 0.02827383, 1.09582681, 1.64693191],
[ 0.59915993, -0.88518848, 0.69847471, 0.85186133],
[ 0.35147571, -0.6568229 , 0.58494554, 0.05679076],
[ 0.47531782, -0.6568229 , 0.64171013, 0.85186133],
[-0.14389274, 2.99702636, -1.2315212 , -1.00330335],
[ 0.59915993, -1.34191964, 0.69847471, 0.45432605],
[ 0.72300204, -0.42845732, 0.3578872 , 0.18930252],
[-0.88694541, 1.62683289, -1.00446286, -1.00330335],
[ 1.21837049, -0.6568229 , 0.64171013, 0.32181428],
[-0.88694541, 0.94173615, -1.28828579, -1.13581511],
[-1.8776823 , -0.20009175, -1.45857955, -1.40083864],
[ 0.10379148, -0.20009175, 0.81200388, 0.85186133],
[ 0.72300204, -0.6568229 , 1.09582681, 1.24939662],
[-0.26773485, -0.6568229 , 0.69847471, 1.11688486],
[-0.39157696, -1.57028522, 0.01729968, -0.20823277],
[ 1.3422126 , 0.02827383, 0.69847471, 0.45432605],
[ 0.59915993, 0.71337057, 1.09582681, 1.64693191],
[ 0.84684415, -0.20009175, 1.20935598, 1.38190839],
[-0.14389274, 1.62683289, -1.11799203, -1.13581511],
[ 0.97068626, -0.42845732, 0.52818096, 0.18930252],
[ 1.09452838, 0.48500499, 1.1525914 , 1.77944368],
[-1.25847174, -0.20009175, -1.28828579, -1.40083864],
[-1.01078752, 1.17010173, -1.28828579, -1.26832688],
[ 0.2276336 , -0.20009175, 0.64171013, 0.85186133],
[-1.01078752, -0.20009175, -1.17475662, -1.26832688],
[ 0.35147571, -0.20009175, 0.69847471, 0.85186133],
[ 0.72300204, 0.02827383, 1.03906223, 0.85186133],
[-0.88694541, 1.39846731, -1.2315212 , -1.00330335],
[-0.14389274, -0.20009175, 0.30112261, 0.05679076],
[-1.01078752, 0.94173615, -1.34505037, -1.13581511],
[-0.88694541, 1.62683289, -1.2315212 , -1.13581511],
[-1.50615597, 0.25663941, -1.28828579, -1.26832688],
[-0.51541907, -0.20009175, 0.47141637, 0.45432605],
[ 0.84684415, -0.6568229 , 0.52818096, 0.45432605],
[ 0.35147571, -0.6568229 , 0.18759344, 0.18930252],
[-1.25847174, 0.71337057, -1.17475662, -1.26832688],
[-0.88694541, 0.48500499, -1.11799203, -0.87079159],
[-0.02005063, -0.88518848, 0.81200388, 0.9843731 ],
[-0.26773485, -0.20009175, 0.24435803, 0.18930252],
[ 0.59915993, -0.6568229 , 0.81200388, 0.45432605],
[ 1.09452838, 0.48500499, 1.1525914 , 1.24939662],
[ 1.71373893, -0.20009175, 1.20935598, 0.58683781],
[ 1.09452838, -0.20009175, 0.86876847, 1.51442015],
[-1.13462963, 0.02827383, -1.2315212 , -1.40083864],
[-1.13462963, -1.34191964, 0.47141637, 0.71934957],
[-0.14389274, -1.34191964, 0.7552393 , 1.11688486],
[-1.13462963, -1.57028522, -0.20975866, -0.20823277],
[-0.39157696, -1.57028522, 0.07406427, -0.07572101],
[ 1.09452838, -1.34191964, 1.20935598, 0.85186133],
[ 0.84684415, -0.20009175, 1.03906223, 0.85186133],
[-0.14389274, -1.11355406, -0.09622949, -0.20823277],
[ 0.2276336 , -2.02701638, 0.7552393 , 0.45432605],
[ 1.09452838, 0.02827383, 0.58494554, 0.45432605],
[-1.13462963, 0.02827383, -1.2315212 , -1.26832688],
[ 0.59915993, -1.34191964, 0.7552393 , 0.9843731 ],
[-1.38231385, 0.25663941, -1.34505037, -1.26832688],
[ 0.2276336 , -0.88518848, 0.81200388, 0.58683781],
[-0.02005063, -1.11355406, 0.18759344, 0.05679076],
[ 1.3422126 , 0.25663941, 1.1525914 , 1.51442015],
[-1.75384019, -0.20009175, -1.34505037, -1.26832688],
[ 1.58989682, -0.20009175, 1.26612057, 1.24939662],
[ 1.21837049, 0.25663941, 1.26612057, 1.51442015],
[-0.7631033 , 0.94173615, -1.2315212 , -1.26832688],
[ 2.58063371, 1.62683289, 1.5499435 , 1.11688486],
[ 0.72300204, -0.6568229 , 1.09582681, 1.38190839],
[-0.26773485, -0.42845732, -0.0394649 , 0.18930252],
[-0.39157696, 2.5402952 , -1.28828579, -1.26832688],
[-1.25847174, -0.20009175, -1.28828579, -1.13581511],
[ 0.59915993, -0.42845732, 1.09582681, 0.85186133],
[-1.75384019, 0.25663941, -1.34505037, -1.26832688],
[-0.51541907, 1.85519847, -1.11799203, -1.00330335],
[-1.01078752, 0.71337057, -1.17475662, -1.00330335],
[ 1.09452838, -0.20009175, 0.7552393 , 0.71934957],
[-0.51541907, 1.85519847, -1.34505037, -1.00330335],
[ 2.33294949, -0.6568229 , 1.72023726, 1.11688486],
[-0.26773485, -0.88518848, 0.30112261, 0.18930252],
[ 1.21837049, -0.20009175, 1.03906223, 1.24939662],
[-0.39157696, 0.94173615, -1.34505037, -1.26832688],
[-1.25847174, 0.71337057, -1.00446286, -1.26832688],
[-0.51541907, 0.71337057, -1.11799203, -1.26832688],
[ 2.33294949, 1.62683289, 1.72023726, 1.38190839],
[ 1.3422126 , 0.02827383, 0.98229764, 1.24939662],
[-0.26773485, -1.34191964, 0.13082885, -0.07572101],
[-0.88694541, 0.71337057, -1.2315212 , -1.26832688],
[-0.88694541, 1.62683289, -1.17475662, -1.26832688],
[ 0.35147571, -0.42845732, 0.58494554, 0.32181428],
[-0.02005063, 2.08356405, -1.40181496, -1.26832688],
[-1.01078752, -2.48374754, -0.09622949, -0.20823277],
[ 0.72300204, 0.25663941, 0.47141637, 0.45432605],
[ 0.35147571, -0.20009175, 0.52818096, 0.32181428],
[ 0.10379148, 0.25663941, 0.64171013, 0.85186133],
[ 0.2276336 , -2.02701638, 0.18759344, -0.20823277],
[ 1.96142316, -0.6568229 , 1.37964974, 0.9843731 ]])
[12]
# X_train is still the unscaled data at this point.
X_train
array([[5.3, 3.7, 1.5, 0.2],
[5. , 3.5, 1.6, 0.6],
[4.4, 2.9, 1.4, 0.2],
[5.8, 2.7, 3.9, 1.2],
[5.2, 3.4, 1.4, 0.2],
[4.6, 3.4, 1.4, 0.3],
[6.5, 3.2, 5.1, 2. ],
[5.7, 2.9, 4.2, 1.3],
[6.6, 3. , 4.4, 1.4],
[6. , 2.9, 4.5, 1.5],
[4.7, 3.2, 1.6, 0.2],
[4.9, 3.6, 1.4, 0.1],
[6.7, 3.1, 5.6, 2.4],
[6.3, 2.7, 4.9, 1.8],
[6.1, 2.8, 4.7, 1.2],
[6.2, 2.8, 4.8, 1.8],
[5.7, 4.4, 1.5, 0.4],
[6.3, 2.5, 4.9, 1.5],
[6.4, 2.9, 4.3, 1.3],
[5.1, 3.8, 1.9, 0.4],
[6.8, 2.8, 4.8, 1.4],
[5.1, 3.5, 1.4, 0.3],
[4.3, 3. , 1.1, 0.1],
[5.9, 3. , 5.1, 1.8],
[6.4, 2.8, 5.6, 2.1],
[5.6, 2.8, 4.9, 2. ],
[5.5, 2.4, 3.7, 1. ],
[6.9, 3.1, 4.9, 1.5],
[6.3, 3.4, 5.6, 2.4],
[6.5, 3. , 5.8, 2.2],
[5.7, 3.8, 1.7, 0.3],
[6.6, 2.9, 4.6, 1.3],
[6.7, 3.3, 5.7, 2.5],
[4.8, 3. , 1.4, 0.1],
[5. , 3.6, 1.4, 0.2],
[6. , 3. , 4.8, 1.8],
[5. , 3. , 1.6, 0.2],
[6.1, 3. , 4.9, 1.8],
[6.4, 3.1, 5.5, 1.8],
[5.1, 3.7, 1.5, 0.4],
[5.7, 3. , 4.2, 1.2],
[5. , 3.5, 1.3, 0.3],
[5.1, 3.8, 1.5, 0.3],
[4.6, 3.2, 1.4, 0.2],
[5.4, 3. , 4.5, 1.5],
[6.5, 2.8, 4.6, 1.5],
[6.1, 2.8, 4. , 1.3],
[4.8, 3.4, 1.6, 0.2],
[5.1, 3.3, 1.7, 0.5],
[5.8, 2.7, 5.1, 1.9],
[5.6, 3. , 4.1, 1.3],
[6.3, 2.8, 5.1, 1.5],
[6.7, 3.3, 5.7, 2.1],
[7.2, 3. , 5.8, 1.6],
[6.7, 3. , 5.2, 2.3],
[4.9, 3.1, 1.5, 0.1],
[4.9, 2.5, 4.5, 1.7],
[5.7, 2.5, 5. , 2. ],
[4.9, 2.4, 3.3, 1. ],
[5.5, 2.4, 3.8, 1.1],
[6.7, 2.5, 5.8, 1.8],
[6.5, 3. , 5.5, 1.8],
[5.7, 2.6, 3.5, 1. ],
[6. , 2.2, 5. , 1.5],
[6.7, 3.1, 4.7, 1.5],
[4.9, 3.1, 1.5, 0.2],
[6.3, 2.5, 5. , 1.9],
[4.7, 3.2, 1.3, 0.2],
[6. , 2.7, 5.1, 1.6],
[5.8, 2.6, 4. , 1.2],
[6.9, 3.2, 5.7, 2.3],
[4.4, 3. , 1.3, 0.2],
[7.1, 3. , 5.9, 2.1],
[6.8, 3.2, 5.9, 2.3],
[5.2, 3.5, 1.5, 0.2],
[7.9, 3.8, 6.4, 2. ],
[6.4, 2.8, 5.6, 2.2],
[5.6, 2.9, 3.6, 1.3],
[5.5, 4.2, 1.4, 0.2],
[4.8, 3. , 1.4, 0.3],
[6.3, 2.9, 5.6, 1.8],
[4.4, 3.2, 1.3, 0.2],
[5.4, 3.9, 1.7, 0.4],
[5. , 3.4, 1.6, 0.4],
[6.7, 3. , 5. , 1.7],
[5.4, 3.9, 1.3, 0.4],
[7.7, 2.8, 6.7, 2. ],
[5.6, 2.7, 4.2, 1.3],
[6.8, 3. , 5.5, 2.1],
[5.5, 3.5, 1.3, 0.2],
[4.8, 3.4, 1.9, 0.2],
[5.4, 3.4, 1.7, 0.2],
[7.7, 3.8, 6.7, 2.2],
[6.9, 3.1, 5.4, 2.1],
[5.6, 2.5, 3.9, 1.1],
[5.1, 3.4, 1.5, 0.2],
[5.1, 3.8, 1.6, 0.2],
[6.1, 2.9, 4.7, 1.4],
[5.8, 4. , 1.2, 0.2],
[5. , 2. , 3.5, 1. ],
[6.4, 3.2, 4.5, 1.5],
[6.1, 3. , 4.6, 1.4],
[5.9, 3.2, 4.8, 1.8],
[6. , 2.2, 4. , 1. ],
[7.4, 2.8, 6.1, 1.9]])
[13]
# Replace X_train with its standardized version. Re-running this cell would
# standardize the already-scaled data a second time.
X_train = standardScaler.transform(X_train)
[14]
# X_train now holds the standardized training data.
X_train
array([[-0.63926119, 1.39846731, -1.2315212 , -1.26832688],
[-1.01078752, 0.94173615, -1.17475662, -0.73827982],
[-1.75384019, -0.42845732, -1.28828579, -1.26832688],
[-0.02005063, -0.88518848, 0.13082885, 0.05679076],
[-0.7631033 , 0.71337057, -1.28828579, -1.26832688],
[-1.50615597, 0.71337057, -1.28828579, -1.13581511],
[ 0.84684415, 0.25663941, 0.81200388, 1.11688486],
[-0.14389274, -0.42845732, 0.30112261, 0.18930252],
[ 0.97068626, -0.20009175, 0.41465178, 0.32181428],
[ 0.2276336 , -0.42845732, 0.47141637, 0.45432605],
[-1.38231385, 0.25663941, -1.17475662, -1.26832688],
[-1.13462963, 1.17010173, -1.28828579, -1.40083864],
[ 1.09452838, 0.02827383, 1.09582681, 1.64693191],
[ 0.59915993, -0.88518848, 0.69847471, 0.85186133],
[ 0.35147571, -0.6568229 , 0.58494554, 0.05679076],
[ 0.47531782, -0.6568229 , 0.64171013, 0.85186133],
[-0.14389274, 2.99702636, -1.2315212 , -1.00330335],
[ 0.59915993, -1.34191964, 0.69847471, 0.45432605],
[ 0.72300204, -0.42845732, 0.3578872 , 0.18930252],
[-0.88694541, 1.62683289, -1.00446286, -1.00330335],
[ 1.21837049, -0.6568229 , 0.64171013, 0.32181428],
[-0.88694541, 0.94173615, -1.28828579, -1.13581511],
[-1.8776823 , -0.20009175, -1.45857955, -1.40083864],
[ 0.10379148, -0.20009175, 0.81200388, 0.85186133],
[ 0.72300204, -0.6568229 , 1.09582681, 1.24939662],
[-0.26773485, -0.6568229 , 0.69847471, 1.11688486],
[-0.39157696, -1.57028522, 0.01729968, -0.20823277],
[ 1.3422126 , 0.02827383, 0.69847471, 0.45432605],
[ 0.59915993, 0.71337057, 1.09582681, 1.64693191],
[ 0.84684415, -0.20009175, 1.20935598, 1.38190839],
[-0.14389274, 1.62683289, -1.11799203, -1.13581511],
[ 0.97068626, -0.42845732, 0.52818096, 0.18930252],
[ 1.09452838, 0.48500499, 1.1525914 , 1.77944368],
[-1.25847174, -0.20009175, -1.28828579, -1.40083864],
[-1.01078752, 1.17010173, -1.28828579, -1.26832688],
[ 0.2276336 , -0.20009175, 0.64171013, 0.85186133],
[-1.01078752, -0.20009175, -1.17475662, -1.26832688],
[ 0.35147571, -0.20009175, 0.69847471, 0.85186133],
[ 0.72300204, 0.02827383, 1.03906223, 0.85186133],
[-0.88694541, 1.39846731, -1.2315212 , -1.00330335],
[-0.14389274, -0.20009175, 0.30112261, 0.05679076],
[-1.01078752, 0.94173615, -1.34505037, -1.13581511],
[-0.88694541, 1.62683289, -1.2315212 , -1.13581511],
[-1.50615597, 0.25663941, -1.28828579, -1.26832688],
[-0.51541907, -0.20009175, 0.47141637, 0.45432605],
[ 0.84684415, -0.6568229 , 0.52818096, 0.45432605],
[ 0.35147571, -0.6568229 , 0.18759344, 0.18930252],
[-1.25847174, 0.71337057, -1.17475662, -1.26832688],
[-0.88694541, 0.48500499, -1.11799203, -0.87079159],
[-0.02005063, -0.88518848, 0.81200388, 0.9843731 ],
[-0.26773485, -0.20009175, 0.24435803, 0.18930252],
[ 0.59915993, -0.6568229 , 0.81200388, 0.45432605],
[ 1.09452838, 0.48500499, 1.1525914 , 1.24939662],
[ 1.71373893, -0.20009175, 1.20935598, 0.58683781],
[ 1.09452838, -0.20009175, 0.86876847, 1.51442015],
[-1.13462963, 0.02827383, -1.2315212 , -1.40083864],
[-1.13462963, -1.34191964, 0.47141637, 0.71934957],
[-0.14389274, -1.34191964, 0.7552393 , 1.11688486],
[-1.13462963, -1.57028522, -0.20975866, -0.20823277],
[-0.39157696, -1.57028522, 0.07406427, -0.07572101],
[ 1.09452838, -1.34191964, 1.20935598, 0.85186133],
[ 0.84684415, -0.20009175, 1.03906223, 0.85186133],
[-0.14389274, -1.11355406, -0.09622949, -0.20823277],
[ 0.2276336 , -2.02701638, 0.7552393 , 0.45432605],
[ 1.09452838, 0.02827383, 0.58494554, 0.45432605],
[-1.13462963, 0.02827383, -1.2315212 , -1.26832688],
[ 0.59915993, -1.34191964, 0.7552393 , 0.9843731 ],
[-1.38231385, 0.25663941, -1.34505037, -1.26832688],
[ 0.2276336 , -0.88518848, 0.81200388, 0.58683781],
[-0.02005063, -1.11355406, 0.18759344, 0.05679076],
[ 1.3422126 , 0.25663941, 1.1525914 , 1.51442015],
[-1.75384019, -0.20009175, -1.34505037, -1.26832688],
[ 1.58989682, -0.20009175, 1.26612057, 1.24939662],
[ 1.21837049, 0.25663941, 1.26612057, 1.51442015],
[-0.7631033 , 0.94173615, -1.2315212 , -1.26832688],
[ 2.58063371, 1.62683289, 1.5499435 , 1.11688486],
[ 0.72300204, -0.6568229 , 1.09582681, 1.38190839],
[-0.26773485, -0.42845732, -0.0394649 , 0.18930252],
[-0.39157696, 2.5402952 , -1.28828579, -1.26832688],
[-1.25847174, -0.20009175, -1.28828579, -1.13581511],
[ 0.59915993, -0.42845732, 1.09582681, 0.85186133],
[-1.75384019, 0.25663941, -1.34505037, -1.26832688],
[-0.51541907, 1.85519847, -1.11799203, -1.00330335],
[-1.01078752, 0.71337057, -1.17475662, -1.00330335],
[ 1.09452838, -0.20009175, 0.7552393 , 0.71934957],
[-0.51541907, 1.85519847, -1.34505037, -1.00330335],
[ 2.33294949, -0.6568229 , 1.72023726, 1.11688486],
[-0.26773485, -0.88518848, 0.30112261, 0.18930252],
[ 1.21837049, -0.20009175, 1.03906223, 1.24939662],
[-0.39157696, 0.94173615, -1.34505037, -1.26832688],
[-1.25847174, 0.71337057, -1.00446286, -1.26832688],
[-0.51541907, 0.71337057, -1.11799203, -1.26832688],
[ 2.33294949, 1.62683289, 1.72023726, 1.38190839],
[ 1.3422126 , 0.02827383, 0.98229764, 1.24939662],
[-0.26773485, -1.34191964, 0.13082885, -0.07572101],
[-0.88694541, 0.71337057, -1.2315212 , -1.26832688],
[-0.88694541, 1.62683289, -1.17475662, -1.26832688],
[ 0.35147571, -0.42845732, 0.58494554, 0.32181428],
[-0.02005063, 2.08356405, -1.40181496, -1.26832688],
[-1.01078752, -2.48374754, -0.09622949, -0.20823277],
[ 0.72300204, 0.25663941, 0.47141637, 0.45432605],
[ 0.35147571, -0.20009175, 0.52818096, 0.32181428],
[ 0.10379148, 0.25663941, 0.64171013, 0.85186133],
[ 0.2276336 , -2.02701638, 0.18759344, -0.20823277],
[ 1.96142316, -0.6568229 , 1.37964974, 0.9843731 ]])
[15]
# Standardize the test split with the statistics learned from the training split.
X_test_standard = standardScaler.transform(X_test)
[16]
# Standardized test data.
X_test_standard
array([[-0.26773485, -0.20009175, 0.47141637, 0.45432605],
[-0.02005063, -0.6568229 , 0.81200388, 1.64693191],
[-1.01078752, -1.7986508 , -0.20975866, -0.20823277],
[-0.02005063, -0.88518848, 0.81200388, 0.9843731 ],
[-1.50615597, 0.02827383, -1.2315212 , -1.26832688],
[-0.39157696, -1.34191964, 0.18759344, 0.18930252],
[-0.14389274, -0.6568229 , 0.47141637, 0.18930252],
[ 0.84684415, -0.20009175, 0.86876847, 1.11688486],
[ 0.59915993, -1.7986508 , 0.41465178, 0.18930252],
[-0.39157696, -1.11355406, 0.41465178, 0.05679076],
[ 1.09452838, 0.02827383, 0.41465178, 0.32181428],
[-1.62999808, -1.7986508 , -1.34505037, -1.13581511],
[-1.25847174, 0.02827383, -1.17475662, -1.26832688],
[-0.51541907, 0.71337057, -1.2315212 , -1.00330335],
[ 1.71373893, 1.17010173, 1.37964974, 1.77944368],
[-0.02005063, -0.88518848, 0.24435803, -0.20823277],
[-1.50615597, 1.17010173, -1.51534413, -1.26832688],
[ 1.71373893, 0.25663941, 1.32288516, 0.85186133],
[ 1.3422126 , 0.02827383, 0.81200388, 1.51442015],
[ 0.72300204, -0.88518848, 0.92553306, 0.9843731 ],
[ 0.59915993, 0.48500499, 0.58494554, 0.58683781],
[-1.01078752, 0.71337057, -1.2315212 , -1.26832688],
[ 2.33294949, -1.11355406, 1.83376643, 1.51442015],
[-1.01078752, 0.48500499, -1.28828579, -1.26832688],
[ 0.47531782, -0.42845732, 0.3578872 , 0.18930252],
[ 0.10379148, -0.20009175, 0.30112261, 0.45432605],
[-1.01078752, 0.25663941, -1.40181496, -1.26832688],
[-0.39157696, -1.7986508 , 0.18759344, 0.18930252],
[ 0.59915993, 0.48500499, 1.32288516, 1.77944368],
[ 2.33294949, -0.20009175, 1.37964974, 1.51442015],
[-0.88694541, 0.94173615, -1.28828579, -1.26832688],
[-1.13462963, -0.20009175, -1.28828579, -1.26832688],
[-0.14389274, -0.6568229 , 0.24435803, 0.18930252],
[ 0.47531782, 0.71337057, 0.98229764, 1.51442015],
[-0.88694541, -1.34191964, -0.38005242, -0.07572101],
[ 1.46605471, 0.25663941, 0.58494554, 0.32181428],
[ 0.35147571, -1.11355406, 1.09582681, 0.32181428],
[ 2.20910738, -0.20009175, 1.66347267, 1.24939662],
[-0.7631033 , 2.31192962, -1.2315212 , -1.40083864],
[ 0.47531782, -2.02701638, 0.47141637, 0.45432605],
[ 1.83758104, -0.42845732, 1.49317891, 0.85186133],
[ 0.72300204, 0.25663941, 0.92553306, 1.51442015],
[ 0.2276336 , 0.71337057, 0.47141637, 0.58683781],
[-0.7631033 , -0.88518848, 0.13082885, 0.32181428],
[-0.51541907, 1.39846731, -1.2315212 , -1.26832688]])
[17]
from sklearn.neighbors import KNeighborsClassifier
[18]
# kNN with k=3, to be trained on the standardized features.
knn_clf = KNeighborsClassifier(n_neighbors=3)
[19]
# Train on the standardized training set.
knn_clf.fit(X_train,y_train)
KNeighborsClassifier(n_neighbors=3)
[20]
# Accuracy on the test set scaled the same way as the training data.
knn_clf.score(X_test_standard,y_test)
0.9777777777777777
[21]
# Deliberate counter-example: scoring the raw (unscaled) X_test against a model
# trained on standardized data gives far lower accuracy, which is why the test
# set must go through the same fitted scaler.
knn_clf.score(X_test,y_test)
0.3333333333333333
4-9 更多有关k近邻算法的思考