我正在学习sklearn,并且编写了一个类Classifier进行常见分类。它需要一个method来确定使用哪个Estimator:

# Classifier
from sklearn.svm import SVC
from sklearn.svm import LinearSVC
from sklearn.linear_model import SGDClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier

class Classifier(object):
    def __init__(self, method='LinearSVC', *args, **kwargs):
        Estimator = getattr(**xxx**, method, None)
        self.Estimator = Estimator
        self._model = Estimator(*args, **kwargs)

    def fit(self, data, target):
        return self._model.fit(data, target)

    def predict(self, data):
        return self._model.predict(data)

    def score(self, X, y, sample_weight=None):
        return self._model.score(X, y, sample_weight=None)

    def persist_model(self):
        pass

    def get_model(self):
        return self._model

    def classification_report(self, expected, predicted):
        return metrics.classification_report(expected, predicted)

    def confusion_matrix(self, expected, predicted):
        return metrics.confusion_matrix(expected, predicted)


我想按名称获取Estimator,但是xxx应该是什么?
还是有更好的方法来做到这一点?
建立一个字典来存储导入的模块?但是这种方法似乎不太好。

最佳答案

在这种情况下,建议直接将类直接用作参数。

您永远不必担心它是字符串:您可以比较LinearSVC is LinearSVC,并将其与其他内容进行比较。

可以将其视为接受整数作为参数,然后将其转换为字符串以使用它:这有意义吗?您只需要一个字符串即可。

建议的代码:

class Classifier(object):
    def __init__(self, model = LinearSVC, *args, **kwargs):
        self._model = model(*args, **kwargs)


然后,您可以执行以下操作:

myclf = Classifier(..., estimator = LinearSVC, ...)
isinstance(myclf._model, LinearSVC)


根据评论:

然后,您还可以在开始时初始化dict,例如:

from sklearn.svm import LinearSVC

str_to_model = {'LinearSVC' : LinearSVC}

class Classifier(object):
    def __init__(self, model = "LinearSVC", *args, **kwargs):
        self._model = str_to_model[model](*args, **kwargs)


KeyError(字符串/模型不存在,并且由于没有定义它们而知道)相比,检查globals听起来更讨厌!

关于python - 按名称获取当前文件中已导入的模块,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/30726173/

10-10 03:59