第4章下 最基础的分类算法-k近邻算法 kNN

4-6 网格搜索与k近邻算法中更多超参数

第4章下 最基础的分类算法-k近邻算法 kNN_第1张图片

 Notbook 示例

第4章下 最基础的分类算法-k近邻算法 kNN_第2张图片

 Notbook 源码

[1]
import numpy as np
from sklearn import datasets
[2]
digits = datasets.load_digits()
X = digits.data
y = digits.target
[3]
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X,y,test_size = 0.3 ,random_state=111 )
[4]
from sklearn.neighbors import KNeighborsClassifier

knn_clf = KNeighborsClassifier( n_neighbors = 6 )
knn_clf.fit(X_train,y_train)
knn_clf.score(X_test,y_test)
0.9833333333333333
Grid Search
[5]
param_gid = [
    {
        'weights': ['unifrom'],
        'n_neighbors': [ i for i in range(1,11)]
    },
    {
        'weights': ['distance'],
        'n_neighbors': [ i for i in range(1,11)],
        'p': [ i for i in range(1,6)]
    }
    
]
[6]
knn_clf = KNeighborsClassifier()
[7]
from sklearn.model_selection import GridSearchCV

grid_search = GridSearchCV(knn_clf,param_gid)
[8]
%%time
grid_search.fit(X_train,y_train)
CPU times: total: 2min 15s
Wall time: 2min 18s

F:\anaconda\lib\site-packages\sklearn\model_selection\_validation.py:372: FitFailedWarning: 
50 fits failed out of a total of 300.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
50 fits failed with the following error:
Traceback (most recent call last):
  File "F:\anaconda\lib\site-packages\sklearn\model_selection\_validation.py", line 680, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "F:\anaconda\lib\site-packages\sklearn\neighbors\_classification.py", line 196, in fit
    self.weights = _check_weights(self.weights)
  File "F:\anaconda\lib\site-packages\sklearn\neighbors\_base.py", line 82, in _check_weights
    raise ValueError(
ValueError: weights not recognized: should be 'uniform', 'distance', or a callable function

  warnings.warn(some_fits_failed_message, FitFailedWarning)
F:\anaconda\lib\site-packages\sklearn\model_selection\_search.py:969: UserWarning: One or more of the test scores are non-finite: [       nan        nan        nan        nan        nan        nan
        nan        nan        nan        nan 0.98011446 0.98965724
 0.98965724 0.99204452 0.98966041 0.98090811 0.98965724 0.98965724
 0.99204452 0.98966041 0.98408904 0.98726997 0.98646999 0.98726681
 0.98488585 0.98249542 0.98806678 0.98886359 0.98886359 0.98726997
 0.98249542 0.98647948 0.98966041 0.98726997 0.98488902 0.98249542
 0.98488585 0.9856795  0.98885727 0.9856795  0.98090179 0.98329539
 0.9856795  0.98806362 0.98488269 0.98010181 0.98170176 0.9856795
 0.98726997 0.98487637 0.97692405 0.98408904 0.98329223 0.98647948
 0.98488585 0.97851135 0.98010814 0.98488269 0.98726997 0.98726997]
  warnings.warn(

GridSearchCV(estimator=KNeighborsClassifier(),
             param_grid=[{'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                          'weights': ['unifrom']},
                         {'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                          'p': [1, 2, 3, 4, 5], 'weights': ['distance']}])
[9]
grid_search.best_estimator_
KNeighborsClassifier(n_neighbors=1, p=4, weights='distance')
[10]
grid_search.best_score_
0.9920445203313729
[11]
grid_search.best_params_
{'n_neighbors': 1, 'p': 4, 'weights': 'distance'}
若 随机数为111,则 grid_search.best_score_ = 0.9920445203313729, grid_search.best_params_ = {'n_neighbors': 1, 'p': 4, 'weights': 'distance'}

[12]
knn_clf = grid_search.best_estimator_
[13]
knn_clf.predict(X_test)
array([7, 1, 2, 9, 5, 8, 6, 4, 8, 3, 8, 4, 4, 5, 7, 1, 6, 1, 0, 6, 6, 8,
       8, 0, 8, 4, 5, 8, 0, 0, 3, 3, 5, 2, 1, 4, 8, 6, 7, 3, 3, 9, 6, 0,
       4, 9, 7, 3, 8, 7, 4, 3, 5, 0, 3, 1, 7, 6, 5, 7, 6, 0, 9, 7, 7, 8,
       2, 8, 6, 6, 1, 1, 2, 6, 4, 6, 4, 8, 6, 9, 8, 1, 3, 4, 4, 2, 0, 7,
       6, 0, 8, 2, 0, 5, 8, 5, 3, 3, 7, 4, 7, 3, 4, 2, 4, 9, 1, 8, 5, 1,
       2, 7, 0, 2, 8, 9, 7, 5, 7, 7, 8, 8, 9, 2, 3, 9, 7, 7, 8, 2, 5, 3,
       2, 4, 0, 1, 4, 8, 7, 9, 6, 8, 1, 5, 2, 6, 1, 4, 1, 6, 5, 3, 4, 2,
       2, 7, 0, 7, 1, 5, 4, 6, 1, 7, 4, 9, 6, 8, 5, 8, 4, 3, 3, 2, 5, 6,
       7, 9, 0, 2, 0, 5, 4, 8, 0, 8, 6, 9, 7, 3, 1, 9, 4, 2, 7, 9, 4, 0,
       5, 2, 8, 2, 9, 1, 8, 5, 4, 5, 7, 7, 5, 5, 0, 1, 4, 4, 6, 5, 7, 6,
       0, 6, 7, 1, 9, 0, 6, 1, 2, 9, 1, 5, 3, 0, 2, 1, 0, 9, 3, 4, 1, 0,
       9, 9, 2, 0, 5, 3, 6, 5, 5, 3, 9, 1, 2, 8, 7, 4, 9, 8, 8, 1, 3, 1,
       6, 3, 0, 7, 2, 4, 7, 2, 5, 0, 6, 4, 7, 4, 1, 0, 3, 1, 8, 0, 5, 6,
       9, 5, 5, 0, 6, 0, 5, 2, 9, 7, 2, 9, 1, 0, 3, 5, 8, 8, 0, 4, 3, 4,
       6, 1, 6, 1, 7, 3, 3, 2, 3, 6, 7, 1, 0, 1, 9, 6, 6, 6, 8, 2, 3, 5,
       9, 4, 4, 5, 3, 9, 7, 1, 3, 0, 0, 8, 6, 9, 7, 9, 6, 4, 2, 7, 2, 6,
       5, 4, 1, 7, 9, 0, 1, 1, 7, 5, 3, 3, 7, 4, 9, 0, 8, 6, 0, 9, 1, 9,
       7, 8, 8, 8, 6, 2, 1, 3, 0, 2, 3, 6, 8, 1, 6, 1, 3, 9, 6, 2, 5, 2,
       9, 7, 7, 6, 5, 8, 0, 1, 8, 6, 3, 5, 0, 4, 3, 9, 9, 3, 4, 3, 7, 9,
       2, 3, 5, 3, 9, 3, 1, 4, 7, 7, 1, 7, 4, 3, 0, 8, 0, 9, 6, 3, 9, 8,
       3, 9, 9, 9, 4, 1, 6, 7, 7, 2, 0, 1, 0, 7, 5, 7, 6, 1, 5, 0, 6, 9,
       5, 1, 2, 1, 7, 5, 2, 1, 8, 1, 8, 8, 2, 8, 6, 8, 7, 0, 9, 9, 6, 2,
       0, 9, 6, 3, 4, 3, 0, 8, 5, 4, 8, 6, 4, 5, 2, 5, 6, 1, 0, 5, 7, 0,
       9, 5, 3, 2, 9, 3, 0, 6, 4, 8, 3, 2, 3, 6, 6, 8, 1, 9, 4, 3, 1, 1,
       4, 5, 4, 3, 7, 5, 3, 3, 7, 8, 1, 0])
[14]
knn_clf.score(X_test,y_test)
0.9907407407407407
[15]
%%time
grid_search = GridSearchCV(knn_clf,param_gid,n_jobs= 4, verbose = 2)
grid_search.fit(X_train,y_train)
# 创建多个分类器来比较,可以并行处理,n_jobs 为分配核的数量,默认为单核 1 .-1为全核。
# verbose,及时输出一些信息,值越大越详细
Fitting 5 folds for each of 60 candidates, totalling 300 fits
CPU times: total: 484 ms
Wall time: 1min 28s

F:\anaconda\lib\site-packages\sklearn\model_selection\_validation.py:372: FitFailedWarning: 
50 fits failed out of a total of 300.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
50 fits failed with the following error:
Traceback (most recent call last):
  File "F:\anaconda\lib\site-packages\sklearn\model_selection\_validation.py", line 680, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "F:\anaconda\lib\site-packages\sklearn\neighbors\_classification.py", line 196, in fit
    self.weights = _check_weights(self.weights)
  File "F:\anaconda\lib\site-packages\sklearn\neighbors\_base.py", line 82, in _check_weights
    raise ValueError(
ValueError: weights not recognized: should be 'uniform', 'distance', or a callable function

  warnings.warn(some_fits_failed_message, FitFailedWarning)
F:\anaconda\lib\site-packages\sklearn\model_selection\_search.py:969: UserWarning: One or more of the test scores are non-finite: [       nan        nan        nan        nan        nan        nan
        nan        nan        nan        nan 0.98011446 0.98965724
 0.98965724 0.99204452 0.98966041 0.98090811 0.98965724 0.98965724
 0.99204452 0.98966041 0.98408904 0.98726997 0.98646999 0.98726681
 0.98488585 0.98249542 0.98806678 0.98886359 0.98886359 0.98726997
 0.98249542 0.98647948 0.98966041 0.98726997 0.98488902 0.98249542
 0.98488585 0.9856795  0.98885727 0.9856795  0.98090179 0.98329539
 0.9856795  0.98806362 0.98488269 0.98010181 0.98170176 0.9856795
 0.98726997 0.98487637 0.97692405 0.98408904 0.98329223 0.98647948
 0.98488585 0.97851135 0.98010814 0.98488269 0.98726997 0.98726997]
  warnings.warn(

GridSearchCV(estimator=KNeighborsClassifier(n_neighbors=1, p=4,
                                            weights='distance'),
             n_jobs=4,
             param_grid=[{'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                          'weights': ['unifrom']},
                         {'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
                          'p': [1, 2, 3, 4, 5], 'weights': ['distance']}],
             verbose=2)
 
  

4-7 数据归一化

第4章下 最基础的分类算法-k近邻算法 kNN_第3张图片

第4章下 最基础的分类算法-k近邻算法 kNN_第4张图片

第4章下 最基础的分类算法-k近邻算法 kNN_第5张图片

第4章下 最基础的分类算法-k近邻算法 kNN_第6张图片

 Notbook 示例

第4章下 最基础的分类算法-k近邻算法 kNN_第7张图片

Notbook 源码

数据归一化处理
[1]
import numpy as np
import matplotlib.pyplot as plt
最值归一化 Normalization
[2]
x = np.random.randint(0,100, size = 100)
[3]
x
array([ 2, 58, 55, 40, 68,  7, 72,  4, 50, 89, 96, 19, 71,  6, 41, 40, 63,
        4,  2, 26, 79, 62, 62, 92, 49, 46, 75, 47, 44, 91, 40, 67, 38, 14,
       13,  0, 93, 15, 20, 50, 38, 31, 41, 22, 36, 85, 64, 87, 98, 65, 31,
       10, 46, 22, 86, 24, 68, 25, 11, 31, 22, 87, 84, 18, 58, 87,  4, 15,
       64, 92, 70, 59, 74, 81, 47, 33,  1, 93,  6, 37, 62, 17, 58, 56, 98,
       53,  2, 70, 36, 17, 21, 66,  1, 79, 60, 71, 89, 71, 61, 30])
[5]
( x - np.min(x)) / (np.max(x) - np.min(x))
array([0.02040816, 0.59183673, 0.56122449, 0.40816327, 0.69387755,
       0.07142857, 0.73469388, 0.04081633, 0.51020408, 0.90816327,
       0.97959184, 0.19387755, 0.7244898 , 0.06122449, 0.41836735,
       0.40816327, 0.64285714, 0.04081633, 0.02040816, 0.26530612,
       0.80612245, 0.63265306, 0.63265306, 0.93877551, 0.5       ,
       0.46938776, 0.76530612, 0.47959184, 0.44897959, 0.92857143,
       0.40816327, 0.68367347, 0.3877551 , 0.14285714, 0.13265306,
       0.        , 0.94897959, 0.15306122, 0.20408163, 0.51020408,
       0.3877551 , 0.31632653, 0.41836735, 0.2244898 , 0.36734694,
       0.86734694, 0.65306122, 0.8877551 , 1.        , 0.66326531,
       0.31632653, 0.10204082, 0.46938776, 0.2244898 , 0.87755102,
       0.24489796, 0.69387755, 0.25510204, 0.1122449 , 0.31632653,
       0.2244898 , 0.8877551 , 0.85714286, 0.18367347, 0.59183673,
       0.8877551 , 0.04081633, 0.15306122, 0.65306122, 0.93877551,
       0.71428571, 0.60204082, 0.75510204, 0.82653061, 0.47959184,
       0.33673469, 0.01020408, 0.94897959, 0.06122449, 0.37755102,
       0.63265306, 0.17346939, 0.59183673, 0.57142857, 1.        ,
       0.54081633, 0.02040816, 0.71428571, 0.36734694, 0.17346939,
       0.21428571, 0.67346939, 0.01020408, 0.80612245, 0.6122449 ,
       0.7244898 , 0.90816327, 0.7244898 , 0.62244898, 0.30612245])
[6]
X = np.random.randint(0,100,(50,2))
[7]
X[:10,:]
array([[19, 14],
       [23, 82],
       [ 4, 17],
       [44, 58],
       [23, 91],
       [46, 17],
       [34, 25],
       [29, 39],
       [69, 61],
       [70, 25]])
[9]
X = np.array(X,dtype = float)
[10]
X[:10,:]
array([[19., 14.],
       [23., 82.],
       [ 4., 17.],
       [44., 58.],
       [23., 91.],
       [46., 17.],
       [34., 25.],
       [29., 39.],
       [69., 61.],
       [70., 25.]])
[18]
X[:,0] = (X[:,0] - np.min(X[:,0])) / ( np.max(X[:,0]) - np.min(X[:,0]))
[19]
X[:,1] = (X[:,1] - np.min(X[:,1])) / ( np.max(X[:,1]) - np.min(X[:,1]))
[20]
X[:10,:]
array([[0.19191919, 0.11458333],
       [0.23232323, 0.82291667],
       [0.04040404, 0.14583333],
       [0.44444444, 0.57291667],
       [0.23232323, 0.91666667],
       [0.46464646, 0.14583333],
       [0.34343434, 0.22916667],
       [0.29292929, 0.375     ],
       [0.6969697 , 0.60416667],
       [0.70707071, 0.22916667]])
[21]
plt.scatter(X[:,0],X[:,1])


[22]
np.mean(X[:,0])
0.4503030303030303
[26]
np.std(X[:,0])
0.32224653972392703
[24]
np.mean(X[:,1])
0.4503030303030303
[27]
np.std(X[:,1])
0.3004160887650475
均值方差归一化
[28]
X2  =  np.random.randint(0,100,(50,2))
[29]
X2 = np.array(X2,dtype = float)
[36]
X2[:,0] = (X2[:,0] - np.mean(X2[:,0])) /  np.std(X2[:,0])
[37]
X2[:,1] = (X2[:,1] - np.mean(X2[:,1])) /  np.std(X2[:,1])
[38]
plt.scatter(X2[:,0],X2[:,1])


[39]
np.mean(X2[:,0])
0.0
[40]
np.std(X2[:,0])
1.0
[41]
np.mean(X2[:,1])
-4.4408920985006264e-17
[42]
np.std(X2[:,1])
1.0

4-8 scikit-learn中的Scaler

第4章下 最基础的分类算法-k近邻算法 kNN_第8张图片

第4章下 最基础的分类算法-k近邻算法 kNN_第9张图片

第4章下 最基础的分类算法-k近邻算法 kNN_第10张图片

第4章下 最基础的分类算法-k近邻算法 kNN_第11张图片

Notbook 示例

第4章下 最基础的分类算法-k近邻算法 kNN_第12张图片

notbook 源码

[1]
import numpy as np
from sklearn import datasets
[2]
iris = datasets.load_iris()
[3]
X = iris.data
y = iris.target
[4]
X[:10,:]
array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1]])
[5]
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size = 0.3 ,random_state=666 )
scikit-learn中的StandardScaler
[6]
from sklearn.preprocessing import StandardScaler
[7]
standardScaler = StandardScaler()
[8]
standardScaler.fit(X_train)
StandardScaler()
[9]
standardScaler.mean_
array([5.81619048, 3.08761905, 3.66952381, 1.15714286])
[10]
standardScaler.scale_ #   .std 旧形式已经弃用 , scale表示数据分布范围
array([0.80747977, 0.43789436, 1.76166176, 0.75464998])
[11]
standardScaler.transform(X_train)
array([[-0.63926119,  1.39846731, -1.2315212 , -1.26832688],
       [-1.01078752,  0.94173615, -1.17475662, -0.73827982],
       [-1.75384019, -0.42845732, -1.28828579, -1.26832688],
       [-0.02005063, -0.88518848,  0.13082885,  0.05679076],
       [-0.7631033 ,  0.71337057, -1.28828579, -1.26832688],
       [-1.50615597,  0.71337057, -1.28828579, -1.13581511],
       [ 0.84684415,  0.25663941,  0.81200388,  1.11688486],
       [-0.14389274, -0.42845732,  0.30112261,  0.18930252],
       [ 0.97068626, -0.20009175,  0.41465178,  0.32181428],
       [ 0.2276336 , -0.42845732,  0.47141637,  0.45432605],
       [-1.38231385,  0.25663941, -1.17475662, -1.26832688],
       [-1.13462963,  1.17010173, -1.28828579, -1.40083864],
       [ 1.09452838,  0.02827383,  1.09582681,  1.64693191],
       [ 0.59915993, -0.88518848,  0.69847471,  0.85186133],
       [ 0.35147571, -0.6568229 ,  0.58494554,  0.05679076],
       [ 0.47531782, -0.6568229 ,  0.64171013,  0.85186133],
       [-0.14389274,  2.99702636, -1.2315212 , -1.00330335],
       [ 0.59915993, -1.34191964,  0.69847471,  0.45432605],
       [ 0.72300204, -0.42845732,  0.3578872 ,  0.18930252],
       [-0.88694541,  1.62683289, -1.00446286, -1.00330335],
       [ 1.21837049, -0.6568229 ,  0.64171013,  0.32181428],
       [-0.88694541,  0.94173615, -1.28828579, -1.13581511],
       [-1.8776823 , -0.20009175, -1.45857955, -1.40083864],
       [ 0.10379148, -0.20009175,  0.81200388,  0.85186133],
       [ 0.72300204, -0.6568229 ,  1.09582681,  1.24939662],
       [-0.26773485, -0.6568229 ,  0.69847471,  1.11688486],
       [-0.39157696, -1.57028522,  0.01729968, -0.20823277],
       [ 1.3422126 ,  0.02827383,  0.69847471,  0.45432605],
       [ 0.59915993,  0.71337057,  1.09582681,  1.64693191],
       [ 0.84684415, -0.20009175,  1.20935598,  1.38190839],
       [-0.14389274,  1.62683289, -1.11799203, -1.13581511],
       [ 0.97068626, -0.42845732,  0.52818096,  0.18930252],
       [ 1.09452838,  0.48500499,  1.1525914 ,  1.77944368],
       [-1.25847174, -0.20009175, -1.28828579, -1.40083864],
       [-1.01078752,  1.17010173, -1.28828579, -1.26832688],
       [ 0.2276336 , -0.20009175,  0.64171013,  0.85186133],
       [-1.01078752, -0.20009175, -1.17475662, -1.26832688],
       [ 0.35147571, -0.20009175,  0.69847471,  0.85186133],
       [ 0.72300204,  0.02827383,  1.03906223,  0.85186133],
       [-0.88694541,  1.39846731, -1.2315212 , -1.00330335],
       [-0.14389274, -0.20009175,  0.30112261,  0.05679076],
       [-1.01078752,  0.94173615, -1.34505037, -1.13581511],
       [-0.88694541,  1.62683289, -1.2315212 , -1.13581511],
       [-1.50615597,  0.25663941, -1.28828579, -1.26832688],
       [-0.51541907, -0.20009175,  0.47141637,  0.45432605],
       [ 0.84684415, -0.6568229 ,  0.52818096,  0.45432605],
       [ 0.35147571, -0.6568229 ,  0.18759344,  0.18930252],
       [-1.25847174,  0.71337057, -1.17475662, -1.26832688],
       [-0.88694541,  0.48500499, -1.11799203, -0.87079159],
       [-0.02005063, -0.88518848,  0.81200388,  0.9843731 ],
       [-0.26773485, -0.20009175,  0.24435803,  0.18930252],
       [ 0.59915993, -0.6568229 ,  0.81200388,  0.45432605],
       [ 1.09452838,  0.48500499,  1.1525914 ,  1.24939662],
       [ 1.71373893, -0.20009175,  1.20935598,  0.58683781],
       [ 1.09452838, -0.20009175,  0.86876847,  1.51442015],
       [-1.13462963,  0.02827383, -1.2315212 , -1.40083864],
       [-1.13462963, -1.34191964,  0.47141637,  0.71934957],
       [-0.14389274, -1.34191964,  0.7552393 ,  1.11688486],
       [-1.13462963, -1.57028522, -0.20975866, -0.20823277],
       [-0.39157696, -1.57028522,  0.07406427, -0.07572101],
       [ 1.09452838, -1.34191964,  1.20935598,  0.85186133],
       [ 0.84684415, -0.20009175,  1.03906223,  0.85186133],
       [-0.14389274, -1.11355406, -0.09622949, -0.20823277],
       [ 0.2276336 , -2.02701638,  0.7552393 ,  0.45432605],
       [ 1.09452838,  0.02827383,  0.58494554,  0.45432605],
       [-1.13462963,  0.02827383, -1.2315212 , -1.26832688],
       [ 0.59915993, -1.34191964,  0.7552393 ,  0.9843731 ],
       [-1.38231385,  0.25663941, -1.34505037, -1.26832688],
       [ 0.2276336 , -0.88518848,  0.81200388,  0.58683781],
       [-0.02005063, -1.11355406,  0.18759344,  0.05679076],
       [ 1.3422126 ,  0.25663941,  1.1525914 ,  1.51442015],
       [-1.75384019, -0.20009175, -1.34505037, -1.26832688],
       [ 1.58989682, -0.20009175,  1.26612057,  1.24939662],
       [ 1.21837049,  0.25663941,  1.26612057,  1.51442015],
       [-0.7631033 ,  0.94173615, -1.2315212 , -1.26832688],
       [ 2.58063371,  1.62683289,  1.5499435 ,  1.11688486],
       [ 0.72300204, -0.6568229 ,  1.09582681,  1.38190839],
       [-0.26773485, -0.42845732, -0.0394649 ,  0.18930252],
       [-0.39157696,  2.5402952 , -1.28828579, -1.26832688],
       [-1.25847174, -0.20009175, -1.28828579, -1.13581511],
       [ 0.59915993, -0.42845732,  1.09582681,  0.85186133],
       [-1.75384019,  0.25663941, -1.34505037, -1.26832688],
       [-0.51541907,  1.85519847, -1.11799203, -1.00330335],
       [-1.01078752,  0.71337057, -1.17475662, -1.00330335],
       [ 1.09452838, -0.20009175,  0.7552393 ,  0.71934957],
       [-0.51541907,  1.85519847, -1.34505037, -1.00330335],
       [ 2.33294949, -0.6568229 ,  1.72023726,  1.11688486],
       [-0.26773485, -0.88518848,  0.30112261,  0.18930252],
       [ 1.21837049, -0.20009175,  1.03906223,  1.24939662],
       [-0.39157696,  0.94173615, -1.34505037, -1.26832688],
       [-1.25847174,  0.71337057, -1.00446286, -1.26832688],
       [-0.51541907,  0.71337057, -1.11799203, -1.26832688],
       [ 2.33294949,  1.62683289,  1.72023726,  1.38190839],
       [ 1.3422126 ,  0.02827383,  0.98229764,  1.24939662],
       [-0.26773485, -1.34191964,  0.13082885, -0.07572101],
       [-0.88694541,  0.71337057, -1.2315212 , -1.26832688],
       [-0.88694541,  1.62683289, -1.17475662, -1.26832688],
       [ 0.35147571, -0.42845732,  0.58494554,  0.32181428],
       [-0.02005063,  2.08356405, -1.40181496, -1.26832688],
       [-1.01078752, -2.48374754, -0.09622949, -0.20823277],
       [ 0.72300204,  0.25663941,  0.47141637,  0.45432605],
       [ 0.35147571, -0.20009175,  0.52818096,  0.32181428],
       [ 0.10379148,  0.25663941,  0.64171013,  0.85186133],
       [ 0.2276336 , -2.02701638,  0.18759344, -0.20823277],
       [ 1.96142316, -0.6568229 ,  1.37964974,  0.9843731 ]])
[12]
X_train
array([[5.3, 3.7, 1.5, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [4.4, 2.9, 1.4, 0.2],
       [5.8, 2.7, 3.9, 1.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.6, 3.4, 1.4, 0.3],
       [6.5, 3.2, 5.1, 2. ],
       [5.7, 2.9, 4.2, 1.3],
       [6.6, 3. , 4.4, 1.4],
       [6. , 2.9, 4.5, 1.5],
       [4.7, 3.2, 1.6, 0.2],
       [4.9, 3.6, 1.4, 0.1],
       [6.7, 3.1, 5.6, 2.4],
       [6.3, 2.7, 4.9, 1.8],
       [6.1, 2.8, 4.7, 1.2],
       [6.2, 2.8, 4.8, 1.8],
       [5.7, 4.4, 1.5, 0.4],
       [6.3, 2.5, 4.9, 1.5],
       [6.4, 2.9, 4.3, 1.3],
       [5.1, 3.8, 1.9, 0.4],
       [6.8, 2.8, 4.8, 1.4],
       [5.1, 3.5, 1.4, 0.3],
       [4.3, 3. , 1.1, 0.1],
       [5.9, 3. , 5.1, 1.8],
       [6.4, 2.8, 5.6, 2.1],
       [5.6, 2.8, 4.9, 2. ],
       [5.5, 2.4, 3.7, 1. ],
       [6.9, 3.1, 4.9, 1.5],
       [6.3, 3.4, 5.6, 2.4],
       [6.5, 3. , 5.8, 2.2],
       [5.7, 3.8, 1.7, 0.3],
       [6.6, 2.9, 4.6, 1.3],
       [6.7, 3.3, 5.7, 2.5],
       [4.8, 3. , 1.4, 0.1],
       [5. , 3.6, 1.4, 0.2],
       [6. , 3. , 4.8, 1.8],
       [5. , 3. , 1.6, 0.2],
       [6.1, 3. , 4.9, 1.8],
       [6.4, 3.1, 5.5, 1.8],
       [5.1, 3.7, 1.5, 0.4],
       [5.7, 3. , 4.2, 1.2],
       [5. , 3.5, 1.3, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [4.6, 3.2, 1.4, 0.2],
       [5.4, 3. , 4.5, 1.5],
       [6.5, 2.8, 4.6, 1.5],
       [6.1, 2.8, 4. , 1.3],
       [4.8, 3.4, 1.6, 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [5.8, 2.7, 5.1, 1.9],
       [5.6, 3. , 4.1, 1.3],
       [6.3, 2.8, 5.1, 1.5],
       [6.7, 3.3, 5.7, 2.1],
       [7.2, 3. , 5.8, 1.6],
       [6.7, 3. , 5.2, 2.3],
       [4.9, 3.1, 1.5, 0.1],
       [4.9, 2.5, 4.5, 1.7],
       [5.7, 2.5, 5. , 2. ],
       [4.9, 2.4, 3.3, 1. ],
       [5.5, 2.4, 3.8, 1.1],
       [6.7, 2.5, 5.8, 1.8],
       [6.5, 3. , 5.5, 1.8],
       [5.7, 2.6, 3.5, 1. ],
       [6. , 2.2, 5. , 1.5],
       [6.7, 3.1, 4.7, 1.5],
       [4.9, 3.1, 1.5, 0.2],
       [6.3, 2.5, 5. , 1.9],
       [4.7, 3.2, 1.3, 0.2],
       [6. , 2.7, 5.1, 1.6],
       [5.8, 2.6, 4. , 1.2],
       [6.9, 3.2, 5.7, 2.3],
       [4.4, 3. , 1.3, 0.2],
       [7.1, 3. , 5.9, 2.1],
       [6.8, 3.2, 5.9, 2.3],
       [5.2, 3.5, 1.5, 0.2],
       [7.9, 3.8, 6.4, 2. ],
       [6.4, 2.8, 5.6, 2.2],
       [5.6, 2.9, 3.6, 1.3],
       [5.5, 4.2, 1.4, 0.2],
       [4.8, 3. , 1.4, 0.3],
       [6.3, 2.9, 5.6, 1.8],
       [4.4, 3.2, 1.3, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [5. , 3.4, 1.6, 0.4],
       [6.7, 3. , 5. , 1.7],
       [5.4, 3.9, 1.3, 0.4],
       [7.7, 2.8, 6.7, 2. ],
       [5.6, 2.7, 4.2, 1.3],
       [6.8, 3. , 5.5, 2.1],
       [5.5, 3.5, 1.3, 0.2],
       [4.8, 3.4, 1.9, 0.2],
       [5.4, 3.4, 1.7, 0.2],
       [7.7, 3.8, 6.7, 2.2],
       [6.9, 3.1, 5.4, 2.1],
       [5.6, 2.5, 3.9, 1.1],
       [5.1, 3.4, 1.5, 0.2],
       [5.1, 3.8, 1.6, 0.2],
       [6.1, 2.9, 4.7, 1.4],
       [5.8, 4. , 1.2, 0.2],
       [5. , 2. , 3.5, 1. ],
       [6.4, 3.2, 4.5, 1.5],
       [6.1, 3. , 4.6, 1.4],
       [5.9, 3.2, 4.8, 1.8],
       [6. , 2.2, 4. , 1. ],
       [7.4, 2.8, 6.1, 1.9]])
[13]
X_train = standardScaler.transform(X_train)
[14]
X_train
array([[-0.63926119,  1.39846731, -1.2315212 , -1.26832688],
       [-1.01078752,  0.94173615, -1.17475662, -0.73827982],
       [-1.75384019, -0.42845732, -1.28828579, -1.26832688],
       [-0.02005063, -0.88518848,  0.13082885,  0.05679076],
       [-0.7631033 ,  0.71337057, -1.28828579, -1.26832688],
       [-1.50615597,  0.71337057, -1.28828579, -1.13581511],
       [ 0.84684415,  0.25663941,  0.81200388,  1.11688486],
       [-0.14389274, -0.42845732,  0.30112261,  0.18930252],
       [ 0.97068626, -0.20009175,  0.41465178,  0.32181428],
       [ 0.2276336 , -0.42845732,  0.47141637,  0.45432605],
       [-1.38231385,  0.25663941, -1.17475662, -1.26832688],
       [-1.13462963,  1.17010173, -1.28828579, -1.40083864],
       [ 1.09452838,  0.02827383,  1.09582681,  1.64693191],
       [ 0.59915993, -0.88518848,  0.69847471,  0.85186133],
       [ 0.35147571, -0.6568229 ,  0.58494554,  0.05679076],
       [ 0.47531782, -0.6568229 ,  0.64171013,  0.85186133],
       [-0.14389274,  2.99702636, -1.2315212 , -1.00330335],
       [ 0.59915993, -1.34191964,  0.69847471,  0.45432605],
       [ 0.72300204, -0.42845732,  0.3578872 ,  0.18930252],
       [-0.88694541,  1.62683289, -1.00446286, -1.00330335],
       [ 1.21837049, -0.6568229 ,  0.64171013,  0.32181428],
       [-0.88694541,  0.94173615, -1.28828579, -1.13581511],
       [-1.8776823 , -0.20009175, -1.45857955, -1.40083864],
       [ 0.10379148, -0.20009175,  0.81200388,  0.85186133],
       [ 0.72300204, -0.6568229 ,  1.09582681,  1.24939662],
       [-0.26773485, -0.6568229 ,  0.69847471,  1.11688486],
       [-0.39157696, -1.57028522,  0.01729968, -0.20823277],
       [ 1.3422126 ,  0.02827383,  0.69847471,  0.45432605],
       [ 0.59915993,  0.71337057,  1.09582681,  1.64693191],
       [ 0.84684415, -0.20009175,  1.20935598,  1.38190839],
       [-0.14389274,  1.62683289, -1.11799203, -1.13581511],
       [ 0.97068626, -0.42845732,  0.52818096,  0.18930252],
       [ 1.09452838,  0.48500499,  1.1525914 ,  1.77944368],
       [-1.25847174, -0.20009175, -1.28828579, -1.40083864],
       [-1.01078752,  1.17010173, -1.28828579, -1.26832688],
       [ 0.2276336 , -0.20009175,  0.64171013,  0.85186133],
       [-1.01078752, -0.20009175, -1.17475662, -1.26832688],
       [ 0.35147571, -0.20009175,  0.69847471,  0.85186133],
       [ 0.72300204,  0.02827383,  1.03906223,  0.85186133],
       [-0.88694541,  1.39846731, -1.2315212 , -1.00330335],
       [-0.14389274, -0.20009175,  0.30112261,  0.05679076],
       [-1.01078752,  0.94173615, -1.34505037, -1.13581511],
       [-0.88694541,  1.62683289, -1.2315212 , -1.13581511],
       [-1.50615597,  0.25663941, -1.28828579, -1.26832688],
       [-0.51541907, -0.20009175,  0.47141637,  0.45432605],
       [ 0.84684415, -0.6568229 ,  0.52818096,  0.45432605],
       [ 0.35147571, -0.6568229 ,  0.18759344,  0.18930252],
       [-1.25847174,  0.71337057, -1.17475662, -1.26832688],
       [-0.88694541,  0.48500499, -1.11799203, -0.87079159],
       [-0.02005063, -0.88518848,  0.81200388,  0.9843731 ],
       [-0.26773485, -0.20009175,  0.24435803,  0.18930252],
       [ 0.59915993, -0.6568229 ,  0.81200388,  0.45432605],
       [ 1.09452838,  0.48500499,  1.1525914 ,  1.24939662],
       [ 1.71373893, -0.20009175,  1.20935598,  0.58683781],
       [ 1.09452838, -0.20009175,  0.86876847,  1.51442015],
       [-1.13462963,  0.02827383, -1.2315212 , -1.40083864],
       [-1.13462963, -1.34191964,  0.47141637,  0.71934957],
       [-0.14389274, -1.34191964,  0.7552393 ,  1.11688486],
       [-1.13462963, -1.57028522, -0.20975866, -0.20823277],
       [-0.39157696, -1.57028522,  0.07406427, -0.07572101],
       [ 1.09452838, -1.34191964,  1.20935598,  0.85186133],
       [ 0.84684415, -0.20009175,  1.03906223,  0.85186133],
       [-0.14389274, -1.11355406, -0.09622949, -0.20823277],
       [ 0.2276336 , -2.02701638,  0.7552393 ,  0.45432605],
       [ 1.09452838,  0.02827383,  0.58494554,  0.45432605],
       [-1.13462963,  0.02827383, -1.2315212 , -1.26832688],
       [ 0.59915993, -1.34191964,  0.7552393 ,  0.9843731 ],
       [-1.38231385,  0.25663941, -1.34505037, -1.26832688],
       [ 0.2276336 , -0.88518848,  0.81200388,  0.58683781],
       [-0.02005063, -1.11355406,  0.18759344,  0.05679076],
       [ 1.3422126 ,  0.25663941,  1.1525914 ,  1.51442015],
       [-1.75384019, -0.20009175, -1.34505037, -1.26832688],
       [ 1.58989682, -0.20009175,  1.26612057,  1.24939662],
       [ 1.21837049,  0.25663941,  1.26612057,  1.51442015],
       [-0.7631033 ,  0.94173615, -1.2315212 , -1.26832688],
       [ 2.58063371,  1.62683289,  1.5499435 ,  1.11688486],
       [ 0.72300204, -0.6568229 ,  1.09582681,  1.38190839],
       [-0.26773485, -0.42845732, -0.0394649 ,  0.18930252],
       [-0.39157696,  2.5402952 , -1.28828579, -1.26832688],
       [-1.25847174, -0.20009175, -1.28828579, -1.13581511],
       [ 0.59915993, -0.42845732,  1.09582681,  0.85186133],
       [-1.75384019,  0.25663941, -1.34505037, -1.26832688],
       [-0.51541907,  1.85519847, -1.11799203, -1.00330335],
       [-1.01078752,  0.71337057, -1.17475662, -1.00330335],
       [ 1.09452838, -0.20009175,  0.7552393 ,  0.71934957],
       [-0.51541907,  1.85519847, -1.34505037, -1.00330335],
       [ 2.33294949, -0.6568229 ,  1.72023726,  1.11688486],
       [-0.26773485, -0.88518848,  0.30112261,  0.18930252],
       [ 1.21837049, -0.20009175,  1.03906223,  1.24939662],
       [-0.39157696,  0.94173615, -1.34505037, -1.26832688],
       [-1.25847174,  0.71337057, -1.00446286, -1.26832688],
       [-0.51541907,  0.71337057, -1.11799203, -1.26832688],
       [ 2.33294949,  1.62683289,  1.72023726,  1.38190839],
       [ 1.3422126 ,  0.02827383,  0.98229764,  1.24939662],
       [-0.26773485, -1.34191964,  0.13082885, -0.07572101],
       [-0.88694541,  0.71337057, -1.2315212 , -1.26832688],
       [-0.88694541,  1.62683289, -1.17475662, -1.26832688],
       [ 0.35147571, -0.42845732,  0.58494554,  0.32181428],
       [-0.02005063,  2.08356405, -1.40181496, -1.26832688],
       [-1.01078752, -2.48374754, -0.09622949, -0.20823277],
       [ 0.72300204,  0.25663941,  0.47141637,  0.45432605],
       [ 0.35147571, -0.20009175,  0.52818096,  0.32181428],
       [ 0.10379148,  0.25663941,  0.64171013,  0.85186133],
       [ 0.2276336 , -2.02701638,  0.18759344, -0.20823277],
       [ 1.96142316, -0.6568229 ,  1.37964974,  0.9843731 ]])
[15]
X_test_standard = standardScaler.transform(X_test)
[16]
X_test_standard
array([[-0.26773485, -0.20009175,  0.47141637,  0.45432605],
       [-0.02005063, -0.6568229 ,  0.81200388,  1.64693191],
       [-1.01078752, -1.7986508 , -0.20975866, -0.20823277],
       [-0.02005063, -0.88518848,  0.81200388,  0.9843731 ],
       [-1.50615597,  0.02827383, -1.2315212 , -1.26832688],
       [-0.39157696, -1.34191964,  0.18759344,  0.18930252],
       [-0.14389274, -0.6568229 ,  0.47141637,  0.18930252],
       [ 0.84684415, -0.20009175,  0.86876847,  1.11688486],
       [ 0.59915993, -1.7986508 ,  0.41465178,  0.18930252],
       [-0.39157696, -1.11355406,  0.41465178,  0.05679076],
       [ 1.09452838,  0.02827383,  0.41465178,  0.32181428],
       [-1.62999808, -1.7986508 , -1.34505037, -1.13581511],
       [-1.25847174,  0.02827383, -1.17475662, -1.26832688],
       [-0.51541907,  0.71337057, -1.2315212 , -1.00330335],
       [ 1.71373893,  1.17010173,  1.37964974,  1.77944368],
       [-0.02005063, -0.88518848,  0.24435803, -0.20823277],
       [-1.50615597,  1.17010173, -1.51534413, -1.26832688],
       [ 1.71373893,  0.25663941,  1.32288516,  0.85186133],
       [ 1.3422126 ,  0.02827383,  0.81200388,  1.51442015],
       [ 0.72300204, -0.88518848,  0.92553306,  0.9843731 ],
       [ 0.59915993,  0.48500499,  0.58494554,  0.58683781],
       [-1.01078752,  0.71337057, -1.2315212 , -1.26832688],
       [ 2.33294949, -1.11355406,  1.83376643,  1.51442015],
       [-1.01078752,  0.48500499, -1.28828579, -1.26832688],
       [ 0.47531782, -0.42845732,  0.3578872 ,  0.18930252],
       [ 0.10379148, -0.20009175,  0.30112261,  0.45432605],
       [-1.01078752,  0.25663941, -1.40181496, -1.26832688],
       [-0.39157696, -1.7986508 ,  0.18759344,  0.18930252],
       [ 0.59915993,  0.48500499,  1.32288516,  1.77944368],
       [ 2.33294949, -0.20009175,  1.37964974,  1.51442015],
       [-0.88694541,  0.94173615, -1.28828579, -1.26832688],
       [-1.13462963, -0.20009175, -1.28828579, -1.26832688],
       [-0.14389274, -0.6568229 ,  0.24435803,  0.18930252],
       [ 0.47531782,  0.71337057,  0.98229764,  1.51442015],
       [-0.88694541, -1.34191964, -0.38005242, -0.07572101],
       [ 1.46605471,  0.25663941,  0.58494554,  0.32181428],
       [ 0.35147571, -1.11355406,  1.09582681,  0.32181428],
       [ 2.20910738, -0.20009175,  1.66347267,  1.24939662],
       [-0.7631033 ,  2.31192962, -1.2315212 , -1.40083864],
       [ 0.47531782, -2.02701638,  0.47141637,  0.45432605],
       [ 1.83758104, -0.42845732,  1.49317891,  0.85186133],
       [ 0.72300204,  0.25663941,  0.92553306,  1.51442015],
       [ 0.2276336 ,  0.71337057,  0.47141637,  0.58683781],
       [-0.7631033 , -0.88518848,  0.13082885,  0.32181428],
       [-0.51541907,  1.39846731, -1.2315212 , -1.26832688]])
[17]
from sklearn.neighbors import KNeighborsClassifier
[18]
knn_clf = KNeighborsClassifier(n_neighbors=3)
[19]
knn_clf.fit(X_train,y_train)
KNeighborsClassifier(n_neighbors=3)
[20]
knn_clf.score(X_test_standard,y_test)
0.9777777777777777
[21]
knn_clf.score(X_test,y_test)
0.3333333333333333
 
  

4-9 更多有关k近邻算法的思考

第4章下 最基础的分类算法-k近邻算法 kNN_第13张图片

第4章下 最基础的分类算法-k近邻算法 kNN_第14张图片

第4章下 最基础的分类算法-k近邻算法 kNN_第15张图片

第4章下 最基础的分类算法-k近邻算法 kNN_第16张图片

第4章下 最基础的分类算法-k近邻算法 kNN_第17张图片

你可能感兴趣的:(机器学习笔记,分类,近邻算法,python,人工智能)