我正在学习sklearn
,并且编写了一个类Classifier
进行常见分类。它需要一个method
来确定使用哪个Estimator:
# Classifier
from sklearn.svm import SVC
from sklearn.svm import LinearSVC
from sklearn.linear_model import SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier
class Classifier(object):
def __init__(self, method='LinearSVC', *args, **kwargs):
Estimator = getattr(**xxx**, method, None)
self.Estimator = Estimator
self._model = Estimator(*args, **kwargs)
def fit(self, data, target):
return self._model.fit(data, target)
def predict(self, data):
return self._model.predict(data)
def score(self, X, y, sample_weight=None):
return self._model.score(X, y, sample_weight=None)
def persist_model(self):
pass
def get_model(self):
return self._model
def classification_report(self, expected, predicted):
return metrics.classification_report(expected, predicted)
def confusion_matrix(self, expected, predicted):
return metrics.confusion_matrix(expected, predicted)
我想按名称获取Estimator,但是xxx应该是什么?
还是有更好的方法来做到这一点?
建立一个字典来存储导入的模块?但是这种方法似乎不太好。
最佳答案
在这种情况下,建议直接将类直接用作参数。
您永远不必担心它是字符串:您可以比较LinearSVC is LinearSVC
,并将其与其他内容进行比较。
可以将其视为接受整数作为参数,然后将其转换为字符串以使用它:这有意义吗?您只需要一个字符串即可。
建议的代码:
class Classifier(object):
def __init__(self, model = LinearSVC, *args, **kwargs):
self._model = model(*args, **kwargs)
然后,您可以执行以下操作:
myclf = Classifier(..., estimator = LinearSVC, ...)
isinstance(myclf._model, LinearSVC)
根据评论:
然后,您还可以在开始时初始化dict,例如:
from sklearn.svm import LinearSVC
str_to_model = {'LinearSVC' : LinearSVC}
class Classifier(object):
def __init__(self, model = "LinearSVC", *args, **kwargs):
self._model = str_to_model[model](*args, **kwargs)
与
KeyError
(字符串/模型不存在,并且由于没有定义它们而知道)相比,检查globals
听起来更讨厌!关于python - 按名称获取当前文件中已导入的模块,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/30726173/