I have implemented my own classifier, and now I want to run a grid search over it, but I get the following error:
estimator.fit(X_train, y_train, **fit_params)
TypeError: fit() takes 2 positional arguments but 3 were given
I followed this tutorial and used this template provided by scikit-learn's official documentation. My class is defined as follows:
from sklearn.base import BaseEstimator, ClassifierMixin

class MyClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, lr=0.1):
        self.lr = lr

    def fit(self, X, y):
        # Some code
        return self

    def predict(self, X):
        # Some code
        return y_pred

    def get_params(self, deep=True):
        # Report the constructor parameters as a dict
        return {'lr': self.lr}

    def set_params(self, **parameters):
        # Assign each given parameter back onto the instance
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self
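For reference, `BaseEstimator` already implements `get_params` and `set_params` by inspecting the `__init__` signature, so the hand-written versions above are not strictly needed as long as `__init__` stores each argument under the same name. A minimal sketch, assuming scikit-learn is installed; `MinimalClassifier` and its majority-label logic are hypothetical placeholders for the elided `# Some code`, and the mixin is listed before `BaseEstimator`, which recent scikit-learn documentation recommends:

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone


class MinimalClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical stand-in: always predicts the most frequent training label."""

    def __init__(self, lr=0.1):
        # Store constructor args unchanged so the inherited get_params can find them.
        self.lr = lr

    def fit(self, X, y):
        # Placeholder training logic: remember the classes and the majority label.
        self.classes_, counts = np.unique(y, return_counts=True)
        self.majority_ = self.classes_[np.argmax(counts)]
        return self

    def predict(self, X):
        # Predict the majority label for every input sample.
        return np.full(len(X), self.majority_)


clf = MinimalClassifier(lr=0.5)
print(clf.get_params())                   # provided by BaseEstimator
print(clone(clf).set_params(lr=0.7).lr)   # clone + set_params, both used by GridSearchCV
```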
I am trying to run a grid search over it as follows:
params = {
    'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)
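Roughly speaking, for each candidate in `param_grid`, `GridSearchCV` clones the base estimator, applies the candidate with `set_params`, and calls `fit` on each training fold; that is where the `estimator.fit(X_train, y_train, **fit_params)` call in the traceback comes from. The loop below is a simplified sketch of that idea, not scikit-learn's actual implementation (no parallelism, no refit, one hard-coded parameter name); `ToyClassifier`, `manual_grid_search` and the toy data are hypothetical:

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.model_selection import StratifiedKFold


class ToyClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical classifier: predicts X % 3 (lr is stored but unused)."""

    def __init__(self, lr=0.1):
        self.lr = lr

    def fit(self, X, y):
        self.classes_ = np.unique(y)
        return self

    def predict(self, X):
        return np.asarray(X) % 3


def manual_grid_search(estimator, param_grid, X, y, n_splits=4):
    """Simplified sketch of the per-candidate loop behind GridSearchCV."""
    cv = StratifiedKFold(n_splits=n_splits)
    results = {}
    for lr in param_grid["lr"]:
        fold_scores = []
        for train_idx, test_idx in cv.split(X, y):
            # Fresh, unfitted copy with the candidate parameter applied.
            est = clone(estimator).set_params(lr=lr)
            # Same call shape as estimator.fit(X_train, y_train, **fit_params).
            est.fit(X[train_idx], y[train_idx], **{})
            fold_scores.append(est.score(X[test_idx], y[test_idx]))
        # Mean accuracy over the folds for this candidate.
        results[lr] = float(np.mean(fold_scores))
    return results


X = np.arange(30)
y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
print(manual_grid_search(ToyClassifier(), {"lr": [0.1, 0.5, 0.7]}, X, y))
```

Because `fit` is called with two positional arguments here, any `fit` that ends up accepting only one besides `self` fails at exactly this point.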
EDIT I
This is how I call it:
gs.fit(['hello world', 'trying', 'hello world', 'trying', 'hello world', 'trying', 'hello world', 'trying'],
       ['I', 'Z', 'I', 'Z', 'I', 'Z', 'I', 'Z'])
END EDIT I
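The error comes from the _fit_and_score method in python3.5/site-packages/sklearn/model_selection/_validation.py, which calls
estimator.fit(X_train, y_train, **fit_params)
with 3 arguments, but my estimator's fit only takes two. So the error makes sense to me, but I don't know how to fix it... I also tried adding some dummy parameters to the fit method, but it didn't work.
EDIT II
Full error output: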
Traceback (most recent call last):
File "/home/rodrigo/no_version/text_classifier/MyClassifier.py", line 355, in <module>
['I', 'Z', 'I', 'Z', 'I', 'Z', 'I', 'Z'])
File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/model_selection/_search.py", line 639, in fit
cv.split(X, y, groups)))
File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 779, in __call__
while self.dispatch_one_batch(iterator):
File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 625, in dispatch_one_batch
self._dispatch(tasks)
File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 588, in _dispatch
job = self._backend.apply_async(batch, callback=cb)
File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py", line 111, in apply_async
result = ImmediateResult(func)
File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py", line 332, in __init__
self.results = batch()
File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 131, in __call__
return [func(*args, **kwargs) for func, args, kwargs in self.items]
File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 131, in <listcomp>
return [func(*args, **kwargs) for func, args, kwargs in self.items]
File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/model_selection/_validation.py", line 458, in _fit_and_score
estimator.fit(X_train, y_train, **fit_params)
TypeError: fit() takes 2 positional arguments but 3 were given
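END EDIT II
SOLVED
Thanks, everyone. I made a silly mistake: there were two different functions with the same name (fit). I had implemented another one with different parameters for a custom purpose, and once I renamed my "custom fit", everything worked.
Thanks, and sorry.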
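The fix above comes down to ordinary Python behaviour: when a class body contains two `def` statements with the same name, the later one silently replaces the earlier one. A minimal reproduction without scikit-learn; `Shadowed`, `Renamed` and `custom_fit` are hypothetical names, and the one-argument second `fit` is an assumption about what the "custom fit" looked like (its exact signature is not given):

```python
class Shadowed:
    def fit(self, X, y):
        # Intended scikit-learn style fit; replaced by the definition below.
        return self

    def fit(self, X):  # noqa: F811 (same name, so it replaces the fit above)
        # The "custom" fit with a different signature wins because it comes last.
        return self


class Renamed:
    def fit(self, X, y):
        return self

    def custom_fit(self, X):
        # Under a different name it no longer shadows fit.
        return self


try:
    Shadowed().fit([1, 2], [0, 1])
except TypeError as exc:
    print(exc)  # the same "takes 2 positional arguments but 3 were given" error

print(Renamed().fit([1, 2], [0, 1]))
```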
Best answer
The following code works for me:
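import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import GridSearchCV

class MyClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, lr=0.1):
        # Store the hyperparameter under the same name so get_params/set_params work
        self.lr = lr

    def fit(self, X, y):
        # Some code
        return self

    def predict(self, X):
        # Some code
        return X % 3

params = {
    'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)
x = np.arange(30)
y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
gs.fit(x, y)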
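The best I can figure is that you are passing something other than x and y to the gs.fit method, or that your MyClassifier.fit method is missing the self argument. The fit_params kwargs are only populated when you pass keyword arguments to gs.fit; otherwise fit_params is an empty dict ({}), and **fit_params will not cause an argument error. To test this, create a classifier instance and pass **{}. For example:
clf = MyClassifier()
clf.fit(x, y, **{})
This will not raise a positional-argument error.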
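The point about `**fit_params` can be checked with plain Python: unpacking an empty dict contributes no arguments at all, while a non-empty one turns into keyword arguments. `fit_like` is a hypothetical stand-in for an estimator's bound `fit`:

```python
def fit_like(X, y):
    # Hypothetical two-argument function, like fit(self, X, y) once bound.
    return (X, y)


fit_params = {}
print(fit_like([1, 2], [0, 1], **fit_params))  # empty dict: no extra arguments

try:
    fit_like([1, 2], [0, 1], **{"some_arg": 123})
except TypeError as exc:
    print(exc)  # a non-empty dict becomes an unexpected keyword argument
```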
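So unless something is passed to gs.fit, e.g. gs.fit(x, y, some_arg=123), it looks to me like you are missing one of the positional arguments in your definition of MyClassifier.fit. The error message you included seems to support this hypothesis, since it states "fit() takes 2 positional arguments but 3 were given". If you define fit as follows, it takes 3 positional arguments:
def fit(self, X, y): ...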
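The missing-`self` hypothesis produces exactly the same message, because Python passes the instance implicitly as the first positional argument of a method call. A minimal sketch; `NoSelf` and `WithSelf` are hypothetical classes:

```python
class NoSelf:
    def fit(X, y):  # missing self: the instance binds to X, leaving one slot for data
        return X


class WithSelf:
    def fit(self, X, y):  # self, X, y: 3 positional parameters
        return self


try:
    NoSelf().fit([1, 2], [0, 1])  # instance + X + y = 3 positional arguments
except TypeError as exc:
    print(exc)

print(isinstance(WithSelf().fit([1, 2], [0, 1]), WithSelf))
```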
Regarding python - scikit learn: custom classifier compatible with GridSearchCV, a similar question can be found on Stack Overflow: https://stackoverflow.com/questions/48211590/