Scikit-Learn | 自定义转换器(transformer)

一、什么是transformer

在Scikit-Learn的设计原则里,所有对象的接口一致且简单。

  • 估计器(estimator):在机器学习中,任何基于数据集,可以对一些参数都被称为估计器(比如RandomForest()、LinearRegression())。
  • 转换器(transformer):可以转换数据集中数值的估计器,如处理缺失值的SimpleImputer(),可参考【处理残缺值(Missing value)python-sklearn实现| 三种常用方法 】。
    它的API调用过程:用过transform()方法进行转换,可参考【fit()、transform()、fit_transform() 三者联系与区别】,被转换的数据集作为参数,返回的是经过转换的数据集。

二、构造转换器

  • 目的:自定义清理操作,或属性组合等。

因为Scikit-Learn是Duck Typing,可参考【编程中的Duck Typing】 (而不是 inheritance).因此我们需要做的就是创建一个类(class)并执行三个方法(method)fit()transform()fit_transform()fit()返回self。
通过添加TransformMixin作为基类,可以自动得到fit_transform()方法。通过添加BaseEstimator作为基类(且构造器中避免使用*args**kargs),就能得到get_params()set_params(),这两个方法可以方便地进行超参数自动微调。

以下是一个简单的tranformer 的构造:

from sklearn.base import BaseEstimator, TransformerMixin

class DataFrameSelector(BaseEstimator, TransformerMixin)
	def __init__(self, attribute_names):  # no *args and **kargs
		self.attributes_names = attribute_names
	def fit(self, X, y=None):
		return self  # 不用做其他事情
	def transform(self, X):
		return X[self.attribute_names].values # 返回的是DataFrame形式里指定的列的值

参考资料:
Hands-on Machine Learning with Scikit-Learn and TensorFlow

你可能感兴趣的:(数据挖掘)