【Sklearn】【StandardScaler】fit()、transform()和fit_transform()的区别

1. 各个函数的作用

  1. fit(): 用来计算mean(均值)和std(标准差),以便后面进行数据的标准化
  2. transform(): 根据fit()函数计算的mean和std对数据进行标准化
  3. fit_transform(): 是fit()函数和transform()函数的组合,先进行fit,之后再进行transform(标准化)
  4. skleran官方文档:sklearn.preprocessing.StandardScaler
  5. Tips: Sklearn中数据的标准化,都是通过均值和标准差进行的,是的数据符合高斯分布
    【Sklearn】【StandardScaler】fit()、transform()和fit_transform()的区别_第1张图片
    【Sklearn】【StandardScaler】fit()、transform()和fit_transform()的区别_第2张图片

2. 常见问题

Q: 为什么我们在使用Sklearn的时候,先使用fit_transform()在训练集上,再使用transform()在测试集上?
A:因为我们必须保证,测试集在进行标准化的时候,使用的是统一的缩放参数,即为均值和标准差。所以必须先进行fit,计算出均值和标准差,最后再通过transform进行数据的标准化。

3. 举例

3.1 简单举例

  1. 代码
from sklearn.preprocessing import StandardScaler
data = [[0, 0], [0, 0], [1, 1], [1, 1]]
data1 = [[0, 0], [0, 0], [2, 2], [2, 2]]
scaler = StandardScaler()
scaler.fit(data)
print("拟合后的均值为:", scaler.mean_)
print("拟合后的方差:", scaler.var_)
print("根据均值和方差标准化之后的数据", '\n', scaler.transform(data))
print("根据data计算出来的均值和方差来标准化data1数据", '\n', scaler.transform(data1))

# 对比实验
scaler.fit(data1)
print("拟合后的均值为:", scaler.mean_)
print("拟合后的方差:", scaler.var_)
print("根据均值和方差标准化之后的数据", '\n', scaler.transform(data))
  1. 执行结果
拟合后的均值为: [0.5 0.5]
拟合后的方差: [0.25 0.25]
根据均值和方差标准化之后的数据 
 [[-1. -1.]
 [-1. -1.]
 [ 1.  1.]
 [ 1.  1.]]
根据均值和方差标准化之后的数据 
 [[-1. -1.]
 [-1. -1.]
 [ 3.  3.]
 [ 3.  3.]]

拟合后的均值为: [1. 1.]
拟合后的方差: [1. 1.]
根据均值和方差标准化之后的数据 
 [[-1. -1.]
 [-1. -1.]
 [ 0.  0.]
 [ 0.  0.]]
  1. 由上可知,data 在拟合之后,计算出来的均值和方差分别为,0.5和0.25. 根据其进行标准化data和data1,我们可以看到data被标准化到了[-1, 1]之间,而data1则超出了这一区间,这一从侧面说明了解释了前述问题。而通过对比实验,重新计算均值和方差,再对data1进行标准化,我们获得了[-1, 1]之间的值。

4. Tips

  1. 在sklearn中,非StanderSclaer类下的fit()、transform()、fit_transform()方法的原理同这一是相同的,只不过这些方法需要先进行fit_transform在训练集上,才能用transform在测试集上,否则会报错!
  2. 如果fit_transfrom(trainData)后,使用fit_transform(testData)而不transform(testData),虽然也能归一化,但是两个结果不是在同一个“标准”下的,具有明显差异。(一定要避免这种情况)

5. 参考资料

  1. sklearn.preprocessing.StandardScaler
  2. What and why behind fit_transform() and transform() in scikit-learn!

你可能感兴趣的:(Sklearn,sklearn,python,机器学习)