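I have recently been reading Hands-On Machine Learning with Scikit-Learn & TensorFlow. When I reached the section on pipelines, I wrote code modeled on the book's example, with LabelBinarizer as one of the pipeline steps, but running it raised an error.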
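A minimal sketch of how this kind of failure can be reproduced. It is not the original listing: the label data and the step name are made up for illustration, and it assumes a scikit-learn version in which LabelBinarizer.fit_transform accepts only one data argument.

```python
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelBinarizer

# Hypothetical label data, used only for illustration.
labels = np.array(["cat", "dog", "bird", "dog"])

# A pipeline whose only step is a plain LabelBinarizer.
pipeline = Pipeline([("label_binarizer", LabelBinarizer())])

try:
    # Pipeline passes both X and y to the step's fit_transform,
    # which LabelBinarizer is not defined to accept.
    pipeline.fit_transform(labels)
    error = None
except TypeError as exc:
    error = exc

print(type(error).__name__, error)
```

Calling LabelBinarizer().fit_transform(labels) directly, outside a pipeline, works; the problem only appears when the Pipeline supplies the extra y argument.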
I suspected that a version update was the cause, and sure enough, I found the answer here. The explanation is quoted below:
The pipeline is assuming LabelBinarizer's fit_transform method is defined to take three positional arguments:

    def fit_transform(self, x, y):
        ...rest of the code

while it is defined to take only two:

    def fit_transform(self, x):
        ...rest of the code
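The mismatch described in the quote can be checked directly with the standard inspect module. This is a small sketch that only prints the parameter names of the two fit_transform methods; the exact names depend on the installed scikit-learn version, so no output is shown.

```python
import inspect

from sklearn.base import TransformerMixin
from sklearn.preprocessing import LabelBinarizer

# Parameter names of LabelBinarizer's own fit_transform.
binarizer_params = list(
    inspect.signature(LabelBinarizer.fit_transform).parameters
)

# Parameter names of the generic fit_transform that TransformerMixin
# provides, which is the shape a Pipeline expects.
mixin_params = list(
    inspect.signature(TransformerMixin.fit_transform).parameters
)

print("LabelBinarizer.fit_transform:", binarizer_params)
print("TransformerMixin.fit_transform:", mixin_params)
```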
So the fix is to write your own MyLabelBinarizer that wraps LabelBinarizer and whose fit and transform methods accept three parameters: self, X, and y=None.
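from sklearn.base import TransformerMixin  # gives fit_transform method for free
from sklearn.preprocessing import LabelBinarizer

class MyLabelBinarizer(TransformerMixin):
    def __init__(self, *args, **kwargs):
        # Wrap a regular LabelBinarizer, forwarding any constructor arguments.
        self.encoder = LabelBinarizer(*args, **kwargs)

    def fit(self, x, y=None):
        # y is accepted only so the Pipeline can pass it; it is ignored.
        self.encoder.fit(x)
        return self

    def transform(self, x, y=None):
        return self.encoder.transform(x)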
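As a usage sketch, the wrapper can be dropped into a pipeline in place of LabelBinarizer. The version below is a variant of the class above, not the original: it also inherits from BaseEstimator and takes an explicit sparse_output parameter instead of *args/**kwargs (both are my assumptions, made so that get_params and clone work), and the label data is hypothetical.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelBinarizer


class MyLabelBinarizer(BaseEstimator, TransformerMixin):
    """LabelBinarizer wrapper whose methods accept the (X, y) a Pipeline passes."""

    def __init__(self, sparse_output=False):
        self.sparse_output = sparse_output

    def fit(self, X, y=None):
        # y is ignored; it exists only to match the Pipeline calling convention.
        self.encoder_ = LabelBinarizer(sparse_output=self.sparse_output)
        self.encoder_.fit(X)
        return self

    def transform(self, X, y=None):
        return self.encoder_.transform(X)


# Hypothetical label data, used only for illustration.
labels = np.array(["cat", "dog", "bird", "dog"])

pipeline = Pipeline([("label_binarizer", MyLabelBinarizer())])
one_hot = pipeline.fit_transform(labels)

print(one_hot.shape)  # one row per label, one column per distinct class
```

Because fit_transform now comes from TransformerMixin, which accepts y, the pipeline call goes through without the TypeError.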