通俗地讲清楚fit_transform()和transform()的区别

通俗地讲清楚fit_transform()和transform()的区别

网上抄来抄去都是一个意思,

fit_transform是fit和transform的组合。


 我们知道fit(x,y)在新手入门的例子中比较多,但是这里的fit_transform(x)的括号中只有一个参数,这是为什么呢?

fit(x,y)传两个参数的是有监督学习的算法,fit(x)传一个参数的是无监督学习的算法,比如降维、特征提取、标准化


然后解释为什么出来fit_transform()这个东西,下面是重点:

fit和transform没有任何关系,仅仅是数据处理的两个不同环节,之所以出来这么个函数名,仅仅是为了写代码方便,

所以会发现transform()和fit_transform()的运行结果是一样的。


注意:运行结果一模一样不代表这两个函数可以互相替换,绝对不可以!!!

transform函数是一定可以替换为fit_transform函数的

fit_transform函数不能替换为transform函数!!!理由解释如下:

 sklearn里的封装好的各种算法都要fit、然后调用各种API方法,transform只是其中一个API方法,所以当你调用除transform之外的方法,必须要先fit,为了通用的写代码,还是分开写比较好 

也就是说,这个fit相对于transform而言是没有任何意义的,但是相对于整个代码而言,fit是为后续的API函数服务的,所以fit_transform不能改写为transform。



下面的代码用来举例示范,数据集是代码自动从网上下载的,如果把下面的乳腺癌相关的机器学习代码中的fit_transform改为transform,编译器就会报错。(下面给出的是无错误的代码)

[python]  view plain  copy
  1. # coding: utf-8  
  2. # 导入pandas与numpy工具包。  
  3. import pandas as pd  
  4. import numpy as np  
  5.   
  6. # 创建特征列表。  
  7. column_names = ['Sample code number''Clump Thickness''Uniformity of Cell Size''Uniformity of Cell Shape''Marginal Adhesion''Single Epithelial Cell Size''Bare Nuclei''Bland Chromatin''Normal Nucleoli''Mitoses''Class']  
  8. # 使用pandas.read_csv函数从互联网读取指定数据。  
  9. data = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/breast-cancer-wisconsin.data', names = column_names )  
  10.   
  11. # 将?替换为标准缺失值表示。  
  12. data = data.replace(to_replace='?', value=np.nan)  
  13. # 丢弃带有缺失值的数据(只要有一个维度有缺失)。  
  14. data = data.dropna(how='any')  
  15.   
  16. # 输出data的数据量和维度。  
  17. data.shape  
  18.   
  19.   
  20. # In[2]:  
  21.   
  22.   
  23. # 使用sklearn.cross_valiation里的train_test_split模块用于分割数据。  
  24. from sklearn.cross_validation import train_test_split  
  25.   
  26. # 随机采样25%的数据用于测试,剩下的75%用于构建训练集合。  
  27. X_train, X_test, y_train, y_test = train_test_split(data[column_names[1:10]], data[column_names[10]], test_size=0.25, random_state=33)  
  28. # print "data[column_names[10]]",data[column_names[10]]  
  29.   
  30.   
  31. # 查验训练样本的数量和类别分布。  
  32. y_train=pd.Series(y_train)  
  33. y_train.value_counts()  
  34.   
  35. # 查验测试样本的数量和类别分布。  
  36. y_test=pd.Series(y_test)  
  37. y_test.value_counts()  
  38.   
  39.   
  40. # 从sklearn.preprocessing里导入StandardScaler。  
  41. from sklearn.preprocessing import StandardScaler  
  42. # 从sklearn.linear_model里导入LogisticRegression与SGDClassifier。  
  43. from sklearn.linear_model import LogisticRegression  
  44. from sklearn.linear_model import SGDClassifier  
  45.   
  46.   
  47.   
  48.   
  49.   
  50.   
  51.   
  52. #标准化数据,保证每个维度的特征数据方差为1,均值为0。使得预测结果不会被某些维度过大的特征值而主导。  
  53. ss = StandardScaler()  
  54. X_train = ss.fit_transform(X_train)  
  55. X_test = ss.transform(X_test)  
  56.   
  57. #  
  58.   
  59. # 初始化LogisticRegression与SGDClassifier。  
  60. lr = LogisticRegression()  
  61. sgdc = SGDClassifier()  
  62.   
  63. # 调用LogisticRegression中的fit函数/模块用来训练模型参数。  
  64. lr.fit(X_train, y_train)  
  65. # 使用训练好的模型lr对X_test进行预测,结果储存在变量lr_y_predict中。  
  66. lr_y_predict = lr.predict(X_test)  
  67.   
  68. # 调用SGDClassifier中的fit函数/模块用来训练模型参数。  
  69. sgdc.fit(X_train, y_train)  
  70. # 使用训练好的模型sgdc对X_test进行预测,结果储存在变量sgdc_y_predict中。  
  71. sgdc_y_predict = sgdc.predict(X_test)  
  72.   
  73.   
  74.   
  75.   
  76. # 从sklearn.metrics里导入classification_report模块。  
  77. from sklearn.metrics import classification_report  
  78.   
  79. # 使用逻辑斯蒂回归模型自带的评分函数score获得模型在测试集上的准确性结果。  
  80. print ("Accuracy of LR Classifier:", lr.score(X_test, y_test))  
  81. # 利用classification_report模块获得LogisticRegression其他三个指标的结果。  
  82. print (classification_report(y_test, lr_y_predict, target_names=['Benign''Malignant']))  
  83.   
  84.   
  85.   
  86. # 使用随机梯度下降模型自带的评分函数score获得模型在测试集上的准确性结果。  
  87. print 'Accuarcy of SGD Classifier:', sgdc.score(X_test, y_test)  
  88. # 利用classification_report模块获得SGDClassifier其他三个指标的结果。  
  89. print classification_report(y_test, sgdc_y_predict, target_names=['Benign''Malignant'])  

会得到报错信息

AttributeError: 'StandardScaler' object has no attribute 'mean_'

有的版本报错更加直接:

sklearn.exceptions.NotFittedError: This StandardScaler instance is not fitted yet. Call 'fit' with appropriate arguments before using this method.

原因就是因为代码中的fit_transform函数被改为了transform函数。

所以总结:

fit_transform与transform运行结果一致,但是fit与transform无关,只是数据处理的两个环节,fit是为了程序的后续函数transform的调用而服务的,是个前提条件。

以上。

你可能感兴趣的:(通俗地讲清楚fit_transform()和transform()的区别)